diff --git a/.travis.yml b/.travis.yml index 27e37fdb9..6726b5afd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,10 @@ language: python -python: - - "3.6" + +matrix: + include: + - python: 3.7 + dist: xenial + sudo: true sudo: required services: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 16579b981..9d8cad8d6 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -20,6 +20,17 @@ All notable changes to this project are documented in this file. - Improve Connection Failure Handling in NodeLeader `#915 `_ - Improve transaction coverage and fix `PublishTransaction.Size()` `#929 `_ - Align ``Fixed8`` ``ToString()`` output with C# `#941 `_ +- Move from using ``Twisted`` to ``asyncio`` as event-driven framework `#934 `_ +- Add new networking code +- Add IP Filtering +- Tighten smart contract storage context validation +- Update VM to 2.4.3 (Also updates ``ApplicationEngine``, ``StateReader`` and ``StateMachine``) +- Add support for new Ping/Pong network payload +- Add neo-vm JSON test support + add new Debugger wrapping class +- Various VM performance updates +- Various code cleaning updates +- Ensure LevelDB iterators are closed, ensure all ``MemoryStream`` usages go through ``StreamManager`` and are closed. + [0.8.4] 2019-02-14 diff --git a/README.rst b/README.rst index e703184a8..b46399f56 100644 --- a/README.rst +++ b/README.rst @@ -27,7 +27,7 @@ Overview What does it currently do ~~~~~~~~~~~~~~~~~~~~~~~~~ -- This project aims to be a full port of the original C# `NEO +- This project aims to be an alternative implementation of the original C# `NEO project `_ - Run a Python based P2P node - Interactive CLI for configuring node and inspecting blockchain @@ -64,7 +64,7 @@ Get help or give help - Open a new `issue `_ if you encounter a problem. -- Or ping **@localhuman**, **@metachris** or **@ixje** on the `NEO +- Or ping **@ixje** on the `NEO Discord `_. - Pull requests welcome. Have a look at the issue list for ideas. You can help with wallet functionality, writing tests or documentation, @@ -77,8 +77,7 @@ neo-python has two System dependencies (everything else is covered with ``pip``): - `LevelDB `_ -- `Python - 3.6 `_ or `Python 3.7 `_ (3.5 and below is not supported) +- `Python 3.7 `_ (3.6 and below is not supported) We have published a Youtube `video `_ to help get you @@ -110,63 +109,19 @@ OSX brew install leveldb -Ubuntu/Debian 16.10+ +Ubuntu/Debian 18.04+ ^^^^^^^^^^^^^^^^^^^^ -Ubuntu starting at 16.10 supports Python 3.6+ in the official repositories. - -First, ensure Ubuntu is fully up-to-date with this: - -:: - - sudo apt-get update && sudo apt-get upgrade - -You can install Python 3.7 and all the system dependencies like this: - -:: - - sudo apt-get install python3.7 python3.7-dev python3.7-venv python3-pip libleveldb-dev libssl-dev g++ - - -Or, you can install Python 3.6 and all the system dependencies like this: - -:: - - sudo apt-get install python3.6 python3.6-dev python3.6-venv python3-pip libleveldb-dev libssl-dev g++ - -Older Ubuntu versions (eg. 16.04) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -For older Ubuntu versions you'll need to use an external repository like -Felix Krull's deadsnakes PPA at -https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa (read more -`here `__). - -(The use of the third-party software links in this documentation is done -at your own discretion and risk and with agreement that you will be -solely responsible for any damage to your computer system or loss of -data that results from such activities.)
- -

::

 - apt-get install software-properties-common python-software-properties - add-apt-repository ppa:deadsnakes/ppa - apt-get update - apt-get install python3.6 python3.6-dev python3.6-venv python3-pip libleveldb-dev libssl-dev g++ +At the time of writing, Ubuntu's package manager only lists Python `3.7.0`. Various memory leaks have been addressed +since then, and it is recommended to run Python `3.7.3` or newer. Read `this Ubuntu guide `_ on how to add an alternative PPA to your sources list to install from, +or how to compile the latest version manually. Centos/Redhat/Fedora ^^^^^^^^^^^^^^^^^^^^ -:: - - # Install Python 3.6: - yum install -y centos-release-scl - yum install -y rh-python36 - scl enable rh-python36 bash +Note: - Not tested - - # Install dependencies: - yum install -y epel-release - yum install -y readline-devel leveldb-devel libffi-devel gcc-c++ redhat-rpm-config gcc python-devel openssl-devel +Please correct the README if there are issues. Install Python `3.7.3` following `this CentOS guide `_. Windows ^^^^^^^ @@ -182,14 +137,13 @@ Help needed for running natively. Installing the Python package plyvel seems to require compiler support tied to Visual Studio and libraries. Refer to `documentation `__. -Python 3.6+ +Python 3.7+ ~~~~~~~~~~~ -neo-python is compatible with **Python 3.6 and later**. +neo-python is compatible with **Python 3.7 and later**. -On \*nix systems, install Python 3.6 or Python 3.7 via your package manager, or -download an installation package from the `official -homepage `__. +On \*nix systems, install Python 3.7 via your package manager, or +download an installation package from the `official homepage `__. Install @@ -214,10 +168,6 @@ could lead to version conflicts. python3.7 -m venv venv source venv/bin/activate - # create virtual environment using Python 3.6 and activate - python3.6 -m venv venv - source venv/bin/activate - # install the package in an editable form (venv) pip install wheel -e . @@ -233,10 +183,6 @@ could lead to version conflicts. python3.7 -m venv venv source venv/bin/activate - # create virtual environment using Python 3.6 and activate - python3.6 -m venv venv - source venv/bin/activate - (venv) pip install wheel neo-python @@ -473,7 +419,7 @@ Troubleshooting If you run into problems, check these things before ripping out your hair: -- Double-check that you are using Python 3.7.x +- Double-check that you are using Python 3.7.x - Update the project dependencies (``pip install -e .``) - If you encounter any problems, please take a look at the `installation diff --git a/docs/source/install.rst b/docs/source/install.rst index 058fd1d8e..d1a5f3e2c 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -139,6 +139,24 @@ The solution probably is brew reinstall openssl +----- + +If you encounter an issue installing the ``scrypt`` module (possibly after updating OSX) with an error like this: + +.. code-block:: sh + + ld: library not found for -lcrypto + clang: error: linker command failed with exit code 1 (use -v to see invocation) + error: command 'gcc' failed with exit status 1 + +The solution probably is + +..
code-block:: sh + + $ brew install openssl + $ export CFLAGS="-I$(brew --prefix openssl)/include $CFLAGS" + $ export LDFLAGS="-L$(brew --prefix openssl)/lib $LDFLAGS" + Install from PyPi ================= diff --git a/examples/node.py b/examples/node.py index efbaec0d2..71f2d5482 100644 --- a/examples/node.py +++ b/examples/node.py @@ -1,18 +1,16 @@ """ -Minimal NEO node with custom code in a background thread. +Minimal NEO node with custom code in a background task. It will log events from all smart contracts on the blockchain as they are seen in the received blocks. """ -import threading -from time import sleep +import asyncio from logzero import logger -from twisted.internet import reactor, task -from neo.Network.NodeLeader import NodeLeader from neo.Core.Blockchain import Blockchain from neo.Implementations.Blockchains.LevelDB.LevelDBBlockchain import LevelDBBlockchain +from neo.Network.p2pservice import NetworkService from neo.Settings import settings @@ -21,16 +19,11 @@ # settings.set_logfile("/tmp/logfile.log", max_bytes=1e7, backup_count=3) -def custom_background_code(): - """ Custom code run in a background thread. - - This function is run in a daemonized thread, which means it can be instantly killed at any - moment, whenever the main thread quits. If you need more safety, don't use a daemonized - thread and handle exiting this thread in another way (eg. with signals and events). - """ +async def custom_background_code(): + """ Custom code run in the background.""" while True: logger.info("Block %s / %s", str(Blockchain.Default().Height), str(Blockchain.Default().HeaderHeight)) - sleep(15) + await asyncio.sleep(15) def main(): @@ -40,18 +33,17 @@ def main(): # Setup the blockchain blockchain = LevelDBBlockchain(settings.chain_leveldb_path) Blockchain.RegisterBlockchain(blockchain) - dbloop = task.LoopingCall(Blockchain.Default().PersistBlocks) - dbloop.start(.1) - NodeLeader.Instance().Start() - - # Start a thread with custom code - d = threading.Thread(target=custom_background_code) - d.setDaemon(True) # daemonizing the thread will kill it when the main thread is quit - d.start() - - # Run all the things (blocking call) - reactor.run() - logger.info("Shutting down.") + + loop = asyncio.get_event_loop() + # Start a reoccurring task with custom code + loop.create_task(custom_background_code()) + p2p = NetworkService() + loop.create_task(p2p.start()) + + # block from here on + loop.run_forever() + + # have a look at the other examples for handling graceful shutdown. if __name__ == "__main__": diff --git a/examples/smart-contract-rest-api.py b/examples/smart-contract-rest-api.py index 4910c13d9..82579e1fe 100644 --- a/examples/smart-contract-rest-api.py +++ b/examples/smart-contract-rest-api.py @@ -6,8 +6,7 @@ Execution.Success and several more. See the documentation here: http://neo-python.readthedocs.io/en/latest/smartcontracts.html -This example requires the environment variable NEO_REST_API_TOKEN, and can -optionally use NEO_REST_LOGFILE and NEO_REST_API_PORT. +This example optionally uses the environment variables NEO_REST_LOGFILE and NEO_REST_API_PORT. 
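For example, assuming both variables are read with os.getenv defaults (as the settings handling further down suggests), the script could be started with a custom logfile and port like this:

    $ NEO_REST_LOGFILE=/tmp/rest-api.log NEO_REST_API_PORT=8080 python smart-contract-rest-api.py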
Example usage (with "123" as valid API token): @@ -18,29 +17,21 @@ $ curl localhost:8080 $ curl -H "Authorization: Bearer 123" localhost:8080/echo/hello123 $ curl -X POST -H "Authorization: Bearer 123" -d '{ "hello": "world" }' localhost:8080/echo-post - -The REST API is using the Python package 'klein', which makes it possible to -create HTTP routes and handlers with Twisted in a similar style to Flask: -https://github.com/twisted/klein """ +import asyncio import os -import threading -import json -from time import sleep +from contextlib import suppress +from signal import SIGINT +from aiohttp import web from logzero import logger -from twisted.internet import reactor, task, endpoints -from twisted.web.server import Request, Site -from klein import Klein, resource -from neo.Network.NodeLeader import NodeLeader from neo.Core.Blockchain import Blockchain from neo.Implementations.Blockchains.LevelDB.LevelDBBlockchain import LevelDBBlockchain +from neo.Network.p2pservice import NetworkService from neo.Settings import settings - -from neo.Network.api.decorators import json_response, gen_authenticated_decorator, catch_exceptions -from neo.contrib.smartcontract import SmartContract from neo.SmartContract.ContractParameter import ContractParameter, ContractParameterType +from neo.contrib.smartcontract import SmartContract # Set the hash of your contract here: SMART_CONTRACT_HASH = "6537b4bd100e514119e3a7ab49d520d20ef2c2a4" @@ -56,19 +47,9 @@ if LOGFILE: settings.set_logfile(LOGFILE, max_bytes=1e7, backup_count=3) -# Internal: get the API token from an environment variable -API_AUTH_TOKEN = os.getenv("NEO_REST_API_TOKEN", None) -if not API_AUTH_TOKEN: - raise Exception("No NEO_REST_API_TOKEN environment variable found!") - # Internal: setup the smart contract instance smart_contract = SmartContract(SMART_CONTRACT_HASH) -# Internal: setup the klein instance -app = Klein() - -# Internal: generate the @authenticated decorator with valid tokens -authenticated = gen_authenticated_decorator(API_AUTH_TOKEN) # # Smart contract event handler for Runtime.Notify events @@ -92,7 +73,7 @@ def sc_notify(event): # # Custom code that runs in the background # -def custom_background_code(): +async def custom_background_code(): """ Custom code run in a background thread. Prints the current block height. 
This function is run in a daemonized thread, which means it can be instantly killed at any @@ -101,78 +82,108 @@ def custom_background_code(): """ while True: logger.info("Block %s / %s", str(Blockchain.Default().Height), str(Blockchain.Default().HeaderHeight)) - sleep(15) + await asyncio.sleep(15) # # REST API Routes # -@app.route('/') -def home(request): - return "Hello world" - - -@app.route('/echo/') -@catch_exceptions -@authenticated -@json_response -def echo_msg(request, msg): - return { - "echo": msg +async def home_route(request): + return web.Response(body="hello world") + + +async def echo_msg(request): + res = { + "echo": request.match_info['msg'] } + return web.json_response(data=res) -@app.route('/echo-post', methods=['POST']) -@catch_exceptions -@authenticated -@json_response -def echo_post(request): +async def echo_post(request): # Parse POST JSON body - body = json.loads(request.content.read().decode("utf-8")) + + body = await request.json() # Echo it - return { + res = { "post-body": body } + return web.json_response(data=res) + # -# Main method which starts everything up +# Main setup method # - -def main(): +async def setup_and_start(loop): # Use TestNet - settings.setup_testnet() + settings.setup_privnet() # Setup the blockchain blockchain = LevelDBBlockchain(settings.chain_leveldb_path) Blockchain.RegisterBlockchain(blockchain) - dbloop = task.LoopingCall(Blockchain.Default().PersistBlocks) - dbloop.start(.1) - NodeLeader.Instance().Start() + + p2p = NetworkService() + loop.create_task(p2p.start()) + bg_task = loop.create_task(custom_background_code()) # Disable smart contract events for external smart contracts settings.set_log_smart_contract_events(False) - # Start a thread with custom code - d = threading.Thread(target=custom_background_code) - d.setDaemon(True) # daemonizing the thread will kill it when the main thread is quit - d.start() - - # Hook up Klein API to Twisted reactor. - endpoint_description = "tcp:port=%s:interface=localhost" % API_PORT + app = web.Application() + app.add_routes([ + web.route('*', '/', home_route), + web.get("/echo-get/{msg}", echo_msg), + web.post("/echo-post/", echo_post), + ]) - # If you want to make this service externally available (not only at localhost), - # then remove the `interface=localhost` part: - # endpoint_description = "tcp:port=%s" % API_PORT - - endpoint = endpoints.serverFromString(reactor, endpoint_description) - endpoint.listen(Site(app.resource())) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "0.0.0.0", API_PORT) + await site.start() # Run all the things (blocking call) logger.info("Everything setup and running. Waiting for events...") - reactor.run() - logger.info("Shutting down.") + return site + + +async def shutdown(): + # cleanup any remaining tasks + for task in asyncio.Task.all_tasks(): + with suppress((asyncio.CancelledError, Exception)): + task.cancel() + await task + + +def system_exit(): + raise SystemExit + + +def main(): + loop = asyncio.get_event_loop() + + # because a KeyboardInterrupt is so violent it can shutdown the DB in an unpredictable state. 
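# Note on the shutdown pattern used here: loop.add_signal_handler() (available on Unix event loops
# only) runs its callback inside the event loop, and system_exit() raises SystemExit. SystemExit is
# a BaseException, so rather than being swallowed by the loop's exception handler it propagates out
# of run_forever(), landing execution in the `except SystemExit:` block below, where the web site,
# the P2P service and the remaining tasks are stopped in order before the database is disposed.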
+ loop.add_signal_handler(SIGINT, system_exit) + + main_task = loop.create_task(setup_and_start(loop)) + + try: + loop.run_forever() + except SystemExit: + logger.info("Shutting down...") + site = main_task.result() + loop.run_until_complete(site.stop()) + + p2p = NetworkService() + loop.run_until_complete(p2p.shutdown()) + + loop.run_until_complete(shutdown()) + loop.stop() + finally: + loop.close() + + logger.info("Closing databases...") + Blockchain.Default().Dispose() if __name__ == "__main__": diff --git a/examples/smart-contract.py b/examples/smart-contract.py index fdc067e45..2535fcc36 100644 --- a/examples/smart-contract.py +++ b/examples/smart-contract.py @@ -7,19 +7,18 @@ http://neo-python.readthedocs.io/en/latest/smartcontracts.html """ -import threading -from time import sleep +import asyncio +from contextlib import suppress +from signal import SIGINT from logzero import logger -from twisted.internet import reactor, task -from neo.contrib.smartcontract import SmartContract -from neo.SmartContract.ContractParameter import ContractParameter, ContractParameterType -from neo.Network.NodeLeader import NodeLeader from neo.Core.Blockchain import Blockchain from neo.Implementations.Blockchains.LevelDB.LevelDBBlockchain import LevelDBBlockchain +from neo.Network.p2pservice import NetworkService from neo.Settings import settings - +from neo.SmartContract.ContractParameter import ContractParameter, ContractParameterType +from neo.contrib.smartcontract import SmartContract # If you want the log messages to also be saved in a logfile, enable the # next line. This configures a logfile with max 10 MB and 3 rotations: @@ -44,41 +43,64 @@ def sc_notify(event): logger.info("- payload part 1: %s", event.event_payload.Value[0].Value.decode("utf-8")) -def custom_background_code(): - """ Custom code run in a background thread. Prints the current block height. - - This function is run in a daemonized thread, which means it can be instantly killed at any - moment, whenever the main thread quits. If you need more safety, don't use a daemonized - thread and handle exiting this thread in another way (eg. with signals and events). - """ +async def custom_background_code(): + """ Custom code run in a background thread. Prints the current block height.""" while True: logger.info("Block %s / %s", str(Blockchain.Default().Height), str(Blockchain.Default().HeaderHeight)) - sleep(15) + await asyncio.sleep(15) -def main(): +async def setup_and_start(loop): # Use TestNet settings.setup_testnet() # Setup the blockchain blockchain = LevelDBBlockchain(settings.chain_leveldb_path) Blockchain.RegisterBlockchain(blockchain) - dbloop = task.LoopingCall(Blockchain.Default().PersistBlocks) - dbloop.start(.1) - NodeLeader.Instance().Start() + + p2p = NetworkService() + loop.create_task(p2p.start()) + bg_task = loop.create_task(custom_background_code()) # Disable smart contract events for external smart contracts settings.set_log_smart_contract_events(False) - # Start a thread with custom code - d = threading.Thread(target=custom_background_code) - d.setDaemon(True) # daemonizing the thread will kill it when the main thread is quit - d.start() - # Run all the things (blocking call) logger.info("Everything setup and running. 
Waiting for events...") - reactor.run() - logger.info("Shutting down.") + return bg_task + + +async def shutdown(): + # cleanup any remaining tasks + for task in asyncio.Task.all_tasks(): + with suppress(asyncio.CancelledError): + task.cancel() + await task + + +def system_exit(): + raise SystemExit + + +def main(): + loop = asyncio.get_event_loop() + + # because a KeyboardInterrupt is so violent it can shutdown the DB in an unpredictable state. + loop.add_signal_handler(SIGINT, system_exit) + main_task = loop.create_task(setup_and_start(loop)) + + try: + loop.run_forever() + except SystemExit: + logger.info("Shutting down...") + p2p = NetworkService() + loop.run_until_complete(p2p.shutdown()) + loop.run_until_complete(shutdown()) + loop.stop() + finally: + loop.close() + + Blockchain.Default().Dispose() if __name__ == "__main__": diff --git a/neo/Blockchain.py b/neo/Blockchain.py index 62203ab92..befede409 100644 --- a/neo/Blockchain.py +++ b/neo/Blockchain.py @@ -1,5 +1,3 @@ - - def GetBlockchain(): from neo.Core.Blockchain import Blockchain return Blockchain.Default() @@ -20,9 +18,25 @@ def GetSystemShare(): return Blockchain.SystemShare() -def GetStateReader(): - from neo.SmartContract.StateReader import StateReader - return StateReader() +def GetStateMachine(): + from neo.SmartContract.StateMachine import StateMachine + from neo.Implementations.Blockchains.LevelDB.DBCollection import DBCollection + from neo.Implementations.Blockchains.LevelDB.DBPrefix import DBPrefix + from neo.Core.State.AccountState import AccountState + from neo.Core.State.AssetState import AssetState + from neo.Core.State.ValidatorState import ValidatorState + from neo.Core.State.ContractState import ContractState + from neo.Core.State.StorageItem import StorageItem + + bc = GetBlockchain() + + accounts = DBCollection(bc._db, DBPrefix.ST_Account, AccountState) + assets = DBCollection(bc._db, DBPrefix.ST_Asset, AssetState) + validators = DBCollection(bc._db, DBPrefix.ST_Validator, ValidatorState) + contracts = DBCollection(bc._db, DBPrefix.ST_Contract, ContractState) + storages = DBCollection(bc._db, DBPrefix.ST_Storage, StorageItem) + + return StateMachine(accounts, validators, assets, contracts, storages, None) def GetConsensusAddress(validators): diff --git a/neo/Core/Block.py b/neo/Core/Block.py index a628a921c..976384617 100644 --- a/neo/Core/Block.py +++ b/neo/Core/Block.py @@ -17,21 +17,6 @@ class Block(BlockBase, InventoryMixin): - # < summary > - # 交易列表 - # < / summary > - Transactions = [] - - # < summary > - # 该区块的区块头 - # < / summary > - - _header = None - - __is_trimmed = False - # < summary > - # 资产清单的类型 - # < / summary > InventoryType = InventoryType.Block def __init__(self, prevHash=None, timestamp=None, index=None, @@ -59,6 +44,8 @@ def __init__(self, prevHash=None, timestamp=None, index=None, self.ConsensusData = consensusData self.NextConsensus = nextConsensus self.Script = script + self._header = None + self.__is_trimmed = False if transactions: self.Transactions = transactions @@ -232,7 +219,8 @@ def FromTrimmedData(byts): for tx_hash in reader.ReadHashes(): tx = bc.GetTransaction(tx_hash)[0] if not tx: - raise Exception("Could not find transaction!\n Are you running code against a valid Blockchain instance?\n Tests that accesses transactions or size of a block but inherit from NeoTestCase instead of BlockchainFixtureTestCase will not work.") + raise Exception( + "Could not find transaction!\n Are you running code against a valid Blockchain instance?\n Tests that accesses transactions or size of 
a block but inherit from NeoTestCase instead of BlockchainFixtureTestCase will not work.") tx_list.append(tx) if len(tx_list) < 1: diff --git a/neo/Core/BlockBase.py b/neo/Core/BlockBase.py index d6c199707..c37bf832b 100644 --- a/neo/Core/BlockBase.py +++ b/neo/Core/BlockBase.py @@ -10,40 +10,19 @@ class BlockBase(VerifiableMixin): - # - # 区块版本 - # - Version = 0 - # - # 前一个区块的散列值 - # - PrevHash = 0 # UInt256 - # - # 该区块中所有交易的Merkle树的根 - # - MerkleRoot = 0 # UInt256 - # - # 时间戳 - # - Timestamp = None - # - # 区块高度 - # - Index = 0 - - ConsensusData = None - # - # 下一个区块的记账合约的散列值 - # - NextConsensus = None # UInt160 - # - # 用于验证该区块的脚本 - # - Script = None - - __hash = None - - __htbs = None + def __init__(self): + + self.Version = 0 + self.PrevHash = 0 # UInt256 + self.MerkleRoot = 0 # UInt256 + self.Timestamp = None + self.Index = 0 + + self.ConsensusData = None + self.NextConsensus = None # UInt160 + self.Script = None + self.__hash = None + self.__htbs = None @property def Hash(self): diff --git a/neo/Core/Blockchain.py b/neo/Core/Blockchain.py index 2a74ad387..9027a9d27 100644 --- a/neo/Core/Blockchain.py +++ b/neo/Core/Blockchain.py @@ -1,7 +1,7 @@ import pytz from itertools import groupby from datetime import datetime -from events import Events +from neo.Network.common import Events from neo.Core.Block import Block from neo.Core.TX.Transaction import TransactionOutput from neo.Core.AssetType import AssetType @@ -19,6 +19,7 @@ from neo.Core.Cryptography.ECCurve import ECDSA from neo.Core.UInt256 import UInt256 from functools import lru_cache +from neo.Network.common import msgrouter from typing import TYPE_CHECKING, Optional @@ -454,6 +455,7 @@ def IsDoubleSpend(self, tx): def OnPersistCompleted(self, block): self.PersistCompleted.on_change(block) + msgrouter.on_block_persisted(block) def BlockCacheCount(self): pass diff --git a/neo/Core/CoinReference.py b/neo/Core/CoinReference.py index d9179c719..72f632021 100644 --- a/neo/Core/CoinReference.py +++ b/neo/Core/CoinReference.py @@ -5,9 +5,6 @@ class CoinReference(SerializableMixin): - PrevHash = None - - PrevIndex = None def __init__(self, prev_hash=None, prev_index=None): """ diff --git a/neo/Core/Cryptography/MerkleTree.py b/neo/Core/Cryptography/MerkleTree.py index 09e3b56ca..b423c6308 100644 --- a/neo/Core/Cryptography/MerkleTree.py +++ b/neo/Core/Cryptography/MerkleTree.py @@ -4,10 +4,6 @@ class MerkleTreeNode(object): - Hash = None - Parent = None - LeftChild = None - RightChild = None def __init__(self, hash=None): """ @@ -17,6 +13,9 @@ def __init__(self, hash=None): hash (bytes): """ self.Hash = hash + self.Parent = None + self.LeftChild = None + self.RightChild = None def IsLeaf(self): """ diff --git a/neo/Core/Fixed8.py b/neo/Core/Fixed8.py index 4fee1d832..901992e7f 100644 --- a/neo/Core/Fixed8.py +++ b/neo/Core/Fixed8.py @@ -11,8 +11,6 @@ class Fixed8: - value = 0 - D = 100000000 """docstring for Fixed8""" diff --git a/neo/Core/FunctionCode.py b/neo/Core/FunctionCode.py index e9350859d..0f73135cb 100644 --- a/neo/Core/FunctionCode.py +++ b/neo/Core/FunctionCode.py @@ -5,15 +5,6 @@ class FunctionCode(SerializableMixin): - Script = bytearray() - - ParameterList = bytearray() - - ReturnType = None - - _scriptHash = None - - ContractProperties = None @property def ReturnTypeBigInteger(self): @@ -53,6 +44,7 @@ def IsPayable(self): return self.ContractProperties & ContractPropertyState.Payable > 0 def __init__(self, script=None, param_list=None, return_type=255, contract_properties=0): + self._scriptHash = None self.Script = script if 
param_list is None: self.ParameterList = [] diff --git a/neo/Core/Helper.py b/neo/Core/Helper.py index 44111a806..2acd6ecb8 100644 --- a/neo/Core/Helper.py +++ b/neo/Core/Helper.py @@ -1,6 +1,6 @@ from base58 import b58decode import binascii -from neo.Blockchain import GetBlockchain, GetStateReader +from neo.Blockchain import GetBlockchain, GetStateMachine from neo.Implementations.Blockchains.LevelDB.CachedScriptTable import CachedScriptTable from neo.Implementations.Blockchains.LevelDB.DBCollection import DBCollection from neo.Implementations.Blockchains.LevelDB.DBPrefix import DBPrefix @@ -194,35 +194,36 @@ def VerifyScripts(verifiable): sb = ScriptBuilder() sb.EmitAppCall(hashes[i].Data) verification = sb.ms.getvalue() + sb.ms.Cleanup() else: verification_hash = Crypto.ToScriptHash(verification, unhex=False) if hashes[i] != verification_hash: logger.debug(f"hash {hashes[i]} does not match verification hash {verification_hash}") return False - state_reader = GetStateReader() + service = GetStateMachine() script_table = CachedScriptTable(DBCollection(blockchain._db, DBPrefix.ST_Contract, ContractState)) - engine = ApplicationEngine(TriggerType.Verification, verifiable, script_table, state_reader, Fixed8.Zero()) + engine = ApplicationEngine(TriggerType.Verification, verifiable, script_table, service, Fixed8.Zero()) engine.LoadScript(verification) invocation = verifiable.Scripts[i].InvocationScript engine.LoadScript(invocation) try: success = engine.Execute() - state_reader.ExecutionCompleted(engine, success) + service.ExecutionCompleted(engine, success) except Exception as e: - state_reader.ExecutionCompleted(engine, False, e) + service.ExecutionCompleted(engine, False, e) if engine.ResultStack.Count != 1 or not engine.ResultStack.Pop().GetBoolean(): - Helper.EmitServiceEvents(state_reader) + Helper.EmitServiceEvents(service) if engine.ResultStack.Count > 0: logger.debug(f"Result stack failure! Count: {engine.ResultStack.Count} bool value: {engine.ResultStack.Pop().GetBoolean()}") else: logger.debug(f"Result stack failure! Count: {engine.ResultStack.Count}") return False - Helper.EmitServiceEvents(state_reader) + Helper.EmitServiceEvents(service) return True diff --git a/neo/Core/IO/Mixins.py b/neo/Core/IO/Mixins.py index bbeeacf40..ca6b38937 100644 --- a/neo/Core/IO/Mixins.py +++ b/neo/Core/IO/Mixins.py @@ -20,5 +20,6 @@ def ToArray(self): class TrackableMixin(object): - Key = None - TrackingState = None + def __init__(self): + self.Key = None + self.TrackingState = None diff --git a/neo/Core/KeyPair.py b/neo/Core/KeyPair.py index 3e814d2cc..fda4add46 100644 --- a/neo/Core/KeyPair.py +++ b/neo/Core/KeyPair.py @@ -20,12 +20,6 @@ class KeyPair(object): - PublicKeyHash = None - - PublicKey = None - - PrivateKey = None - def setup_curve(self): """ Setup the Elliptic curve parameters. 
@@ -52,6 +46,9 @@ def __init__(self, priv_key): if the input `priv_key` length is 32 but the public key still could not be determined """ self.setup_curve() + self.PublicKeyHash = None + self.PublicKey = None + self.PrivateKey = None length = len(priv_key) diff --git a/neo/Core/Mixins.py b/neo/Core/Mixins.py index 22eab89d2..ec9eb1180 100644 --- a/neo/Core/Mixins.py +++ b/neo/Core/Mixins.py @@ -9,15 +9,18 @@ def clone(self): class CodeMixin: - scripts = [] - parameter_list = [] - return_type = None - script_hash = None + def __init__(self): + self.scripts = [] + self.parameter_list = [] + self.return_type = None + self.script_hash = None class VerifiableMixin(ABC, SerializableMixin): - scripts = [] + def __init__(self): + super(VerifiableMixin, self).__init__() + self.scripts = [] # # 反序列化未签名的数据 diff --git a/neo/Core/State/AccountState.py b/neo/Core/State/AccountState.py index a6fa95b0a..976d5e23c 100644 --- a/neo/Core/State/AccountState.py +++ b/neo/Core/State/AccountState.py @@ -11,11 +11,6 @@ class AccountState(StateBase): - ScriptHash = None - IsFrozen = False - Votes = [] - Balances = {} - def __init__(self, script_hash=None, is_frozen=False, votes=None, balances=None): """ Create an instance. diff --git a/neo/Core/State/AssetState.py b/neo/Core/State/AssetState.py index 551c43138..e24c8c242 100644 --- a/neo/Core/State/AssetState.py +++ b/neo/Core/State/AssetState.py @@ -11,27 +11,12 @@ class AssetState(StateBase): - AssetId = None - AssetType = None - Name = None - Amount = Fixed8(0) - Available = Fixed8(0) - Precision = 0 - FeeMode = 0 - Fee = Fixed8(0) - FeeAddress = None - Owner = None - Admin = None - Issuer = None - Expiration = None - IsFrozen = False - def Size(self): return super(AssetState, self).Size() + s.uint256 + s.uint8 + GetVarSize( self.Name) + self.Amount.Size() + self.Available.Size() + s.uint8 + s.uint8 + self.Fee.Size() + s.uint160 + self.Owner.Size() + s.uint160 + s.uint160 + s.uint32 + s.uint8 - def __init__(self, asset_id=None, asset_type=None, name=None, amount=Fixed8(0), available=Fixed8(0), - precision=0, fee_mode=0, fee=Fixed8(0), fee_addr=UInt160(data=bytearray(20)), owner=None, + def __init__(self, asset_id=None, asset_type=None, name=None, amount=None, available=None, + precision=0, fee_mode=0, fee=None, fee_addr=None, owner=None, admin=None, issuer=None, expiration=None, is_frozen=False): """ Create an instance. 
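# A note on the AssetState change around this point: the hunk above replaces default arguments like
# amount=Fixed8(0) and fee_addr=UInt160(data=bytearray(20)) with None, and the hunk below adds
# "Fixed8(0) if amount is None else amount"-style sentinels. Python evaluates default values once,
# when the function is defined, so the old defaults handed the very same Fixed8/UInt160 objects to
# every AssetState constructed without explicit values; the None-sentinel form (presumably the
# motivation for this change) builds a fresh object per instance.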
@@ -56,12 +41,12 @@ def __init__(self, asset_id=None, asset_type=None, name=None, amount=Fixed8(0), self.AssetType = asset_type self.Name = name - self.Amount = amount - self.Available = available + self.Amount = Fixed8(0) if amount is None else amount + self.Available = Fixed8(0) if available is None else available self.Precision = precision self.FeeMode = fee_mode - self.Fee = fee - self.FeeAddress = fee_addr + self.Fee = Fixed8(0) if fee is None else fee + self.FeeAddress = UInt160(data=bytearray(20)) if fee_addr is None else fee_addr if owner is not None and type(owner) is not EllipticCurve.ECPoint: raise Exception("Owner must be ECPoint Instance") diff --git a/neo/Core/State/ContractState.py b/neo/Core/State/ContractState.py index efd2353f7..c84256eab 100644 --- a/neo/Core/State/ContractState.py +++ b/neo/Core/State/ContractState.py @@ -17,17 +17,6 @@ class ContractPropertyState(IntEnum): class ContractState(StateBase): - Code = None - ContractProperties = None - Name = None - CodeVersion = None - Author = None - Email = None - Description = None - - _is_nep5 = None - _nep_token = None - @property def HasStorage(self): """ @@ -83,6 +72,9 @@ def __init__(self, code=None, contract_properties=0, name=None, version=None, au email (bytes): description (bytes): """ + self._is_nep5 = None + self._nep_token = None + self.Code = code self.ContractProperties = contract_properties self.Name = name diff --git a/neo/Core/State/SpentCoinState.py b/neo/Core/State/SpentCoinState.py index dc6de45ac..545524430 100644 --- a/neo/Core/State/SpentCoinState.py +++ b/neo/Core/State/SpentCoinState.py @@ -5,9 +5,6 @@ class SpentCoinItem: - index = None - height = None - def __init__(self, index, height): """ Create an instance. @@ -21,10 +18,6 @@ def __init__(self, index, height): class SpentCoin: - Output = None - StartHeight = None - EndHeight = None - @property def Value(self): """ @@ -74,14 +67,6 @@ def ToJson(self): class SpentCoinState(StateBase): - Output = None - StartHeight = None - EndHeight = None - - TransactionHash = None - TransactionHeight = None - Items = [] - def __init__(self, hash=None, height=None, items=None): """ Create an instance. @@ -98,6 +83,10 @@ def __init__(self, hash=None, height=None, items=None): else: self.Items = items + self.Output = None + self.StartHeight = None + self.EndHeight = None + def HasIndex(self, index): """ Flag indicating the index exists in any of the spent coin items. 
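# The same theme runs through most of the remaining State/TX changes in this patch: class-level
# attributes such as Votes = [], Balances = {} or Items = [] are moved into __init__. A minimal,
# illustrative sketch of why that matters, using made-up classes rather than the real ones:

class SharedItems:
    Items = []  # evaluated once at class creation; every instance reaches the same list


class OwnItems:
    def __init__(self):
        self.Items = []  # a fresh list per instance


a, b = SharedItems(), SharedItems()
a.Items.append(1)
assert b.Items == [1]  # b observes a's mutation through the shared class attribute

c, d = OwnItems(), OwnItems()
c.Items.append(1)
assert d.Items == []  # no shared state between instances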
diff --git a/neo/Core/State/StateDescriptor.py b/neo/Core/State/StateDescriptor.py index 91b213e22..5a766bf5c 100644 --- a/neo/Core/State/StateDescriptor.py +++ b/neo/Core/State/StateDescriptor.py @@ -15,10 +15,13 @@ class StateType(Enum): class StateDescriptor(SerializableMixin): - Type = None - Key = None # byte[] - Field = None # string - Value = None # byte[] + + def __init__(self): + super().__init__() + self.Type = None + self.Key = None # byte[] + self.Field = None # string + self.Value = None # byte[] @property def SystemFee(self): @@ -145,8 +148,8 @@ def Verify(self): raise Exception("Invalid State Descriptor") def VerifyAccountState(self): - # @TODO - # Implement VerifyAccount State + # TODO + # Implement VerifyAccount State raise NotImplementedError() def VerifyValidatorState(self): diff --git a/neo/Core/State/StorageItem.py b/neo/Core/State/StorageItem.py index 6429fa68c..9fab4696b 100644 --- a/neo/Core/State/StorageItem.py +++ b/neo/Core/State/StorageItem.py @@ -5,7 +5,6 @@ class StorageItem(StateBase): - Value = None def __init__(self, value=None): """ diff --git a/neo/Core/State/StorageKey.py b/neo/Core/State/StorageKey.py index 612099017..659429236 100644 --- a/neo/Core/State/StorageKey.py +++ b/neo/Core/State/StorageKey.py @@ -4,8 +4,6 @@ class StorageKey(): - ScriptHash = None - Key = None def __init__(self, script_hash=None, key=None): """ diff --git a/neo/Core/State/UnspentCoinState.py b/neo/Core/State/UnspentCoinState.py index 31a3fb506..41b65291f 100644 --- a/neo/Core/State/UnspentCoinState.py +++ b/neo/Core/State/UnspentCoinState.py @@ -9,8 +9,6 @@ class UnspentCoinState(StateBase): - Items = None - def __init__(self, items=None): """ Create an instance. diff --git a/neo/Core/State/ValidatorState.py b/neo/Core/State/ValidatorState.py index c08eaabec..b4e844fd8 100644 --- a/neo/Core/State/ValidatorState.py +++ b/neo/Core/State/ValidatorState.py @@ -9,10 +9,6 @@ class ValidatorState(StateBase): - PublicKey = None # ECPoint - Registered = False # bool - Votes = Fixed8.Zero() - def __init__(self, pub_key=None): """ Create an instance. @@ -27,6 +23,8 @@ def __init__(self, pub_key=None): raise Exception("Pubkey must be ECPoint Instance") self.PublicKey = pub_key + self.Registered = False + self.Votes = Fixed8.Zero() def Size(self): """ diff --git a/neo/Core/TX/ClaimTransaction.py b/neo/Core/TX/ClaimTransaction.py index aced9b947..b7434175d 100644 --- a/neo/Core/TX/ClaimTransaction.py +++ b/neo/Core/TX/ClaimTransaction.py @@ -10,8 +10,6 @@ class ClaimTransaction(Transaction): - Claims = set() - def Size(self): """ Get the total size in bytes of the object. 
@@ -32,6 +30,7 @@ def __init__(self, *args, **kwargs): super(ClaimTransaction, self).__init__(*args, **kwargs) self.Type = TransactionType.ClaimTransaction + self.Claims = set() def NetworkFee(self): """ diff --git a/neo/Core/TX/EnrollmentTransaction.py b/neo/Core/TX/EnrollmentTransaction.py index 72375b313..b77533048 100644 --- a/neo/Core/TX/EnrollmentTransaction.py +++ b/neo/Core/TX/EnrollmentTransaction.py @@ -3,8 +3,6 @@ class EnrollmentTransaction(Transaction): - PublicKey = None - _script_hash = None def __init__(self, *args, **kwargs): """ @@ -16,6 +14,8 @@ def __init__(self, *args, **kwargs): """ super(EnrollmentTransaction, self).__init__(*args, **kwargs) self.Type = TransactionType.EnrollmentTransaction + self.PublicKey = None + self._script_hash = None def Size(self): """ diff --git a/neo/Core/TX/InvocationTransaction.py b/neo/Core/TX/InvocationTransaction.py index 07eb477de..7b3078c88 100644 --- a/neo/Core/TX/InvocationTransaction.py +++ b/neo/Core/TX/InvocationTransaction.py @@ -5,9 +5,6 @@ class InvocationTransaction(Transaction): - Script = None - Gas = None - def SystemFee(self): """ Get the system fee. @@ -28,6 +25,7 @@ def __init__(self, *args, **kwargs): super(InvocationTransaction, self).__init__(*args, **kwargs) self.Gas = Fixed8(0) self.Type = TransactionType.InvocationTransaction + self.Script = None def Size(self): """ diff --git a/neo/Core/TX/IssueTransaction.py b/neo/Core/TX/IssueTransaction.py index 40242ef06..24b0884fb 100644 --- a/neo/Core/TX/IssueTransaction.py +++ b/neo/Core/TX/IssueTransaction.py @@ -10,8 +10,6 @@ class IssueTransaction(Transaction): - Nonce = None - """docstring for IssueTransaction""" def __init__(self, *args, **kwargs): @@ -24,6 +22,7 @@ def __init__(self, *args, **kwargs): """ super(IssueTransaction, self).__init__(*args, **kwargs) self.Type = TransactionType.IssueTransaction # 0x40 + self.Nonce = None def SystemFee(self): """ diff --git a/neo/Core/TX/MinerTransaction.py b/neo/Core/TX/MinerTransaction.py index 437f3b9ed..cfb6390d2 100644 --- a/neo/Core/TX/MinerTransaction.py +++ b/neo/Core/TX/MinerTransaction.py @@ -6,7 +6,6 @@ class MinerTransaction(Transaction): - Nonce = None def __init__(self, *args, **kwargs): """ @@ -18,6 +17,7 @@ def __init__(self, *args, **kwargs): """ super(MinerTransaction, self).__init__(*args, **kwargs) self.Type = TransactionType.MinerTransaction + self.Nonce = None def NetworkFee(self): """ diff --git a/neo/Core/TX/PublishTransaction.py b/neo/Core/TX/PublishTransaction.py index bd45b2eed..3be66f3b3 100644 --- a/neo/Core/TX/PublishTransaction.py +++ b/neo/Core/TX/PublishTransaction.py @@ -8,14 +8,6 @@ class PublishTransaction(Transaction): - Code = None - NeedStorage = False - Name = '' - CodeVersion = '' - Author = '' - Email = '' - Description = '' - def __init__(self, *args, **kwargs): """ Create instance. @@ -26,6 +18,13 @@ def __init__(self, *args, **kwargs): """ super(PublishTransaction, self).__init__(*args, **kwargs) self.Type = TransactionType.PublishTransaction + self.Code = None + self.NeedStorage = False + self.Name = '' + self.CodeVersion = '' + self.Author = '' + self.Email = '' + self.Description = '' def Size(self): """ @@ -34,7 +33,8 @@ def Size(self): Returns: int: size. 
""" - return super(PublishTransaction, self).Size() + GetVarSize(self.Code.Script) + GetVarSize(self.Code.ParameterList) + s.uint8 + GetVarSize(self.Name) + GetVarSize(self.CodeVersion) + GetVarSize(self.Author) + GetVarSize(self.Email) + GetVarSize(self.Description) + return super(PublishTransaction, self).Size() + GetVarSize(self.Code.Script) + GetVarSize(self.Code.ParameterList) + s.uint8 + GetVarSize( + self.Name) + GetVarSize(self.CodeVersion) + GetVarSize(self.Author) + GetVarSize(self.Email) + GetVarSize(self.Description) def DeserializeExclusiveData(self, reader): """ diff --git a/neo/Core/TX/StateTransaction.py b/neo/Core/TX/StateTransaction.py index c38c8175b..50ecba311 100644 --- a/neo/Core/TX/StateTransaction.py +++ b/neo/Core/TX/StateTransaction.py @@ -4,7 +4,6 @@ class StateTransaction(Transaction): - Descriptors = None def Size(self): """ @@ -27,6 +26,7 @@ def __init__(self, *args, **kwargs): super(StateTransaction, self).__init__(*args, **kwargs) self.Type = TransactionType.StateTransaction + self.Descriptors = None def SystemFee(self): amount = Fixed8.Zero() diff --git a/neo/Core/TX/Transaction.py b/neo/Core/TX/Transaction.py index c837d7605..123c23fc4 100644 --- a/neo/Core/TX/Transaction.py +++ b/neo/Core/TX/Transaction.py @@ -30,8 +30,6 @@ class TransactionResult(EquatableMixin): - AssetId = None - Amount = Fixed8(0) def __init__(self, asset_id, amount): """ @@ -78,10 +76,6 @@ def ToName(value): class TransactionOutput(SerializableMixin, EquatableMixin): - Value = None # should be fixed 8 - ScriptHash = None - AssetId = None - """docstring for TransactionOutput""" def __init__(self, AssetId=None, Value=None, script_hash=None): @@ -168,9 +162,6 @@ def Size(self): class TransactionInput(SerializableMixin, EquatableMixin): """docstring for TransactionInput""" - PrevHash = None - PrevIndex = None - def __init__(self, prevHash=None, prevIndex=None): """ Create an instance. @@ -226,7 +217,6 @@ def ToJson(self): class Transaction(InventoryMixin): Version = 0 - __system_fee = None InventoryType = InventoryType.TX MAX_TX_ATTRIBUTES = 16 @@ -250,6 +240,7 @@ def __init__(self, inputs=None, outputs=None, attributes=None, scripts=None): self.raw_tx = False self.withdraw_hold = None self._network_fee = None + self.__system_fee = None self.__hash = None self.__htbs = None self.__height = 0 @@ -405,7 +396,7 @@ def Deserialize(self, reader): """ self.DeserializeUnsigned(reader) - self.scripts = reader.ReadSerializableArray() + self.scripts = reader.ReadSerializableArray('neo.Core.Witness.Witness') self.OnDeserialized() def DeserializeExclusiveData(self, reader): @@ -596,7 +587,7 @@ def Verify(self, mempool): Returns: bool: True if verified. False otherwise. 
""" - logger.info("Verifying transaction: %s " % self.Hash.ToBytes()) + logger.debug("Verifying transaction: %s " % self.Hash.ToBytes()) return Helper.VerifyScripts(self) diff --git a/neo/Core/UIntBase.py b/neo/Core/UIntBase.py index ad2d2ff16..a570deb48 100644 --- a/neo/Core/UIntBase.py +++ b/neo/Core/UIntBase.py @@ -4,8 +4,6 @@ class UIntBase(SerializableMixin): - Data = bytearray() - __hash = None def __init__(self, num_bytes, data=None): """ @@ -20,7 +18,7 @@ def __init__(self, num_bytes, data=None): TypeError: if the input `data` is not bytes or bytearray """ super(UIntBase, self).__init__() - + self.__hash = None if data is None: self.Data = bytearray(num_bytes) diff --git a/neo/Core/VerificationCode.py b/neo/Core/VerificationCode.py index 84cb8e56e..08cc31a83 100644 --- a/neo/Core/VerificationCode.py +++ b/neo/Core/VerificationCode.py @@ -7,14 +7,6 @@ class VerificationCode: - Script = None - - ParameterList = None - - ReturnType = ContractParameterType.Boolean - - _scriptHash = None - @property def ScriptHash(self): @@ -31,3 +23,5 @@ def ScriptHash(self): def __init__(self, script=None, param_list=None): self.Script = script self.ParameterList = param_list + self.ReturnType = ContractParameterType.Boolean + self._scriptHash = None diff --git a/neo/Core/Witness.py b/neo/Core/Witness.py index 19201b59a..72e97f984 100644 --- a/neo/Core/Witness.py +++ b/neo/Core/Witness.py @@ -1,12 +1,9 @@ -import sys import binascii from neo.Core.IO.Mixins import SerializableMixin from neo.Core.Size import GetVarSize class Witness(SerializableMixin): - InvocationScript = None - VerificationScript = None def __init__(self, invocation_script=None, verification_script=None): try: @@ -36,17 +33,10 @@ def Deserialize(self, reader): self.VerificationScript = reader.ReadVarBytes() def Serialize(self, writer): - # logger.info("Serializing Witnes.....") - # logger.info("INVOCATION %s " % self.InvocationScript) writer.WriteVarBytes(self.InvocationScript) - # logger.info("writer after invocation %s " % writer.stream.ToArray()) - # logger.info("Now wringi verificiation script %s " % self.VerificationScript) writer.WriteVarBytes(self.VerificationScript) - # logger.info("Wrote verification script %s " % writer.stream.ToArray()) - def ToJson(self): - # logger.info("invocation %s " % self.InvocationScript) data = { 'invocation': self.InvocationScript.hex(), 'verification': self.VerificationScript.hex() diff --git a/neo/IO/Helper.py b/neo/IO/Helper.py index 61b4fd7c9..09803477f 100644 --- a/neo/IO/Helper.py +++ b/neo/IO/Helper.py @@ -1,5 +1,5 @@ import importlib -from .MemoryStream import MemoryStream, StreamManager +from .MemoryStream import StreamManager from neo.Core.IO.BinaryReader import BinaryReader from neo.Core.TX.Transaction import Transaction from neo.logging import log_manager @@ -49,9 +49,9 @@ def DeserializeTX(buffer): Returns: neo.Core.TX.Transaction: """ - mstream = MemoryStream(buffer) + mstream = StreamManager.GetStream(buffer) reader = BinaryReader(mstream) tx = Transaction.DeserializeFrom(reader) - + mstream.Cleanup() return tx diff --git a/neo/Implementations/Blockchains/LevelDB/CachedScriptTable.py b/neo/Implementations/Blockchains/LevelDB/CachedScriptTable.py index 4c73ada35..4f2c54b12 100644 --- a/neo/Implementations/Blockchains/LevelDB/CachedScriptTable.py +++ b/neo/Implementations/Blockchains/LevelDB/CachedScriptTable.py @@ -3,13 +3,10 @@ class CachedScriptTable(ScriptTableMixin): - contracts = None - def __init__(self, contracts): self.contracts = contracts def GetScript(self, script_hash): - 
contract = self.contracts.TryGet(script_hash) if contract is not None: @@ -18,7 +15,6 @@ def GetScript(self, script_hash): return None def GetContractState(self, script_hash): - contract = self.contracts.TryGet(script_hash) return contract diff --git a/neo/Implementations/Blockchains/LevelDB/DBCollection.py b/neo/Implementations/Blockchains/LevelDB/DBCollection.py index b659b43c1..964dbc2e2 100644 --- a/neo/Implementations/Blockchains/LevelDB/DBCollection.py +++ b/neo/Implementations/Blockchains/LevelDB/DBCollection.py @@ -6,24 +6,10 @@ class DBCollection: - DB = None - Prefix = None - - ClassRef = None - - Collection = {} - - Changed = [] - Deleted = [] - - _built_keys = False - - DebugStorage = False - - _ChangedResetState = None - _DeletedResetState = None def __init__(self, db, prefix, class_ref): + self._built_keys = False + self.DebugStorage = False self.DB = db @@ -59,10 +45,11 @@ def Current(self): return {} def _BuildCollectionKeys(self): - for key in self.DB.iterator(prefix=self.Prefix, include_value=False): - key = key[1:] - if key not in self.Collection.keys(): - self.Collection[key] = None + with self.DB.iterator(prefix=self.Prefix, include_value=False) as it: + for key in it: + key = key[1:] + if key not in self.Collection.keys(): + self.Collection[key] = None def Commit(self, wb, destroy=True): @@ -216,12 +203,13 @@ def TryFind(self, key_prefix): def Find(self, key_prefix): key_prefix = self.Prefix + key_prefix res = {} - for key, val in self.DB.iterator(prefix=key_prefix): - # we want the storage item, not the raw bytes - item = self.ClassRef.DeserializeFromDB(binascii.unhexlify(val)).Value - # also here we need to skip the 1 byte storage prefix - res_key = key[21:] - res[res_key] = item + with self.DB.iterator(prefix=key_prefix) as it: + for key, val in it: + # we want the storage item, not the raw bytes + item = self.ClassRef.DeserializeFromDB(binascii.unhexlify(val)).Value + # also here we need to skip the 1 byte storage prefix + res_key = key[21:] + res[res_key] = item return res def Destroy(self): diff --git a/neo/Implementations/Blockchains/LevelDB/DebugStorage.py b/neo/Implementations/Blockchains/LevelDB/DebugStorage.py index 61da6cb30..a05fb63b1 100644 --- a/neo/Implementations/Blockchains/LevelDB/DebugStorage.py +++ b/neo/Implementations/Blockchains/LevelDB/DebugStorage.py @@ -15,13 +15,15 @@ def db(self): return self._db def reset(self): - for key in self._db.iterator(prefix=DBPrefix.ST_Storage, include_value=False): - self._db.delete(key) + with self._db.iterator(prefix=DBPrefix.ST_Storage, include_value=False) as it: + for key in it: + self._db.delete(key) def clone_from_live(self): clone_db = GetBlockchain()._db.snapshot() - for key, value in clone_db.iterator(prefix=DBPrefix.ST_Storage, include_value=True): - self._db.put(key, value) + with clone_db.iterator(prefix=DBPrefix.ST_Storage, include_value=True) as it: + for key, value in it: + self._db.put(key, value) def __init__(self): diff --git a/neo/Implementations/Blockchains/LevelDB/LevelDBBlockchain.py b/neo/Implementations/Blockchains/LevelDB/LevelDBBlockchain.py index e75c4b709..f273f949d 100644 --- a/neo/Implementations/Blockchains/LevelDB/LevelDBBlockchain.py +++ b/neo/Implementations/Blockchains/LevelDB/LevelDBBlockchain.py @@ -1,6 +1,8 @@ import plyvel +import asyncio import binascii import struct +import traceback from neo.Core.Blockchain import Blockchain from neo.Core.Header import Header from neo.Core.Block import Block @@ -28,10 +30,11 @@ from neo.SmartContract.ApplicationEngine import 
ApplicationEngine from neo.SmartContract import TriggerType from neo.Core.Cryptography.Crypto import Crypto -from neo.Core.BigInteger import BigInteger from neo.EventHub import events +from typing import Tuple -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt +from neo.Network.common import wait_for from neo.logging import log_manager logger = log_manager.getLogger('db') @@ -53,7 +56,7 @@ class LevelDBBlockchain(Blockchain): # this is the version of the database # should not be updated for network version changes - _sysversion = b'schema v.0.6.9' + _sysversion = b'schema v.0.8.5' _persisting_block = None @@ -109,7 +112,7 @@ def __init__(self, path, skip_version_check=False, skip_header_check=False): self.TXProcessed = 0 try: - self._db = plyvel.DB(self._path, create_if_missing=True) + self._db = plyvel.DB(self._path, create_if_missing=True, max_open_files=100, lru_cache_size=10 * 1024 * 1024) logger.info("Created Blockchain DB at %s " % self._path) except Exception as e: logger.info("leveldb unavailable, you may already be running this process: %s " % e) @@ -133,13 +136,14 @@ def __init__(self, path, skip_version_check=False, skip_header_check=False): hashes = [] try: - for key, value in self._db.iterator(prefix=DBPrefix.IX_HeaderHashList): - ms = StreamManager.GetStream(value) - reader = BinaryReader(ms) - hlist = reader.Read2000256List() - key = int.from_bytes(key[-4:], 'little') - hashes.append({'k': key, 'v': hlist}) - StreamManager.ReleaseStream(ms) + with self._db.iterator(prefix=DBPrefix.IX_HeaderHashList) as it: + for key, value in it: + ms = StreamManager.GetStream(value) + reader = BinaryReader(ms) + hlist = reader.Read2000256List() + key = int.from_bytes(key[-4:], 'little') + hashes.append({'k': key, 'v': hlist}) + StreamManager.ReleaseStream(ms) except Exception as e: logger.info("Could not get stored header hash list: %s " % e) @@ -156,9 +160,10 @@ def __init__(self, path, skip_version_check=False, skip_header_check=False): if self._stored_header_count == 0: logger.info("Current stored headers empty, re-creating from stored blocks...") headers = [] - for key, value in self._db.iterator(prefix=DBPrefix.DATA_Block): - dbhash = bytearray(value)[8:] - headers.append(Header.FromTrimmedData(binascii.unhexlify(dbhash), 0)) + with self._db.iterator(prefix=DBPrefix.DATA_Block) as it: + for key, value in it: + dbhash = bytearray(value)[8:] + headers.append(Header.FromTrimmedData(binascii.unhexlify(dbhash), 0)) headers.sort(key=lambda h: h.Index) for h in headers: @@ -186,7 +191,8 @@ def __init__(self, path, skip_version_check=False, skip_header_check=False): pass elif version is None: - self.Persist(Blockchain.GenesisBlock()) + + wait_for(self.Persist(Blockchain.GenesisBlock())) self._db.put(DBPrefix.SYS_Version, self._sysversion) else: logger.error("\n\n") @@ -200,8 +206,9 @@ def __init__(self, path, skip_version_check=False, skip_header_check=False): if res == 'continue': with self._db.write_batch() as wb: - for key, value in self._db.iterator(): - wb.delete(key) + with self._db.iterator() as it: + for key, value in it: + wb.delete(key) self.Persist(Blockchain.GenesisBlock()) self._db.put(DBPrefix.SYS_Version, self._sysversion) @@ -589,7 +596,7 @@ def GetNextBlockHash(self, hash): return None def AddHeader(self, header): - self.AddHeaders([header]) + return self.AddHeaders([header]) def AddHeaders(self, headers): @@ -614,7 +621,7 @@ def AddHeaders(self, headers): if len(newheaders): self.ProcessNewHeaders(newheaders) - return True + return 
count def ProcessNewHeaders(self, headers): @@ -655,10 +662,7 @@ def OnAddHeader(self, header): def BlockCacheCount(self): return len(self._block_cache) - def Persist(self, block): - - self._persisting_block = block - + async def Persist(self, block): accounts = DBCollection(self._db, DBPrefix.ST_Account, AccountState) unspentcoins = DBCollection(self._db, DBPrefix.ST_Coin, UnspentCoinState) spentcoins = DBCollection(self._db, DBPrefix.ST_SpentCoin, SpentCoinState) @@ -676,7 +680,6 @@ def Persist(self, block): wb.put(DBPrefix.DATA_Block + block.Hash.ToBytes(), amount_sysfee_bytes + block.Trim()) for tx in block.Transactions: - wb.put(DBPrefix.DATA_Transaction + tx.Hash.ToBytes(), block.IndexBytes() + tx.ToArray()) # go through all outputs and add unspent coins to them @@ -778,6 +781,7 @@ def Persist(self, block): service.ExecutionCompleted(engine, False, e) to_dispatch = to_dispatch + service.events_to_dispatch + await asyncio.sleep(0.001) else: if tx.Type != b'\x00' and tx.Type != 128: @@ -816,13 +820,29 @@ def Persist(self, block): wb.put(DBPrefix.SYS_CurrentBlock, block.Hash.ToBytes() + block.IndexBytes()) self._current_block_height = block.Index - self._persisting_block = None self.TXProcessed += len(block.Transactions) for event in to_dispatch: events.emit(event.event_type, event) + async def TryPersist(self, block: 'Block') -> Tuple[bool, str]: + distance = self._current_block_height - block.Index + + if distance >= 0: + return False, "Block already exists" + + if distance < -1: + return False, f"Trying to persist block {block.Index} but expecting next block to be {self._current_block_height + 1}" + + try: + await self.Persist(block) + except Exception as e: + traceback.print_exc() + return False, f"{e}" + + return True, "" + def PersistBlocks(self, limit=None): ctr = 0 if not self._paused: diff --git a/neo/Implementations/Notifications/LevelDB/NotificationDB.py b/neo/Implementations/Notifications/LevelDB/NotificationDB.py index c4d85344f..6a0d9698d 100644 --- a/neo/Implementations/Notifications/LevelDB/NotificationDB.py +++ b/neo/Implementations/Notifications/LevelDB/NotificationDB.py @@ -224,9 +224,10 @@ def get_by_block(self, block_number): blocklist_snapshot = self.db.prefixed_db(NotificationPrefix.PREFIX_BLOCK).snapshot() block_bytes = block_number.to_bytes(4, 'little') results = [] - for val in blocklist_snapshot.iterator(prefix=block_bytes, include_key=False): - event = SmartContractEvent.FromByteArray(val) - results.append(event) + with blocklist_snapshot.iterator(prefix=block_bytes, include_key=False) as it: + for val in it: + event = SmartContractEvent.FromByteArray(val) + results.append(event) return results @@ -249,13 +250,14 @@ def get_by_addr(self, address): addrlist_snapshot = self.db.prefixed_db(NotificationPrefix.PREFIX_ADDR).snapshot() results = [] - for val in addrlist_snapshot.iterator(prefix=bytes(addr.Data), include_key=False): - if len(val) > 4: - try: - event = SmartContractEvent.FromByteArray(val) - results.append(event) - except Exception as e: - logger.error("could not parse event: %s %s" % (e, val)) + with addrlist_snapshot.iterator(prefix=bytes(addr.Data), include_key=False) as it: + for val in it: + if len(val) > 4: + try: + event = SmartContractEvent.FromByteArray(val) + results.append(event) + except Exception as e: + logger.error("could not parse event: %s %s" % (e, val)) return results def get_by_contract(self, contract_hash): @@ -277,13 +279,14 @@ def get_by_contract(self, contract_hash): contractlist_snapshot = 
self.db.prefixed_db(NotificationPrefix.PREFIX_CONTRACT).snapshot() results = [] - for val in contractlist_snapshot.iterator(prefix=bytes(hash.Data), include_key=False): - if len(val) > 4: - try: - event = SmartContractEvent.FromByteArray(val) - results.append(event) - except Exception as e: - logger.error("could not parse event: %s %s" % (e, val)) + with contractlist_snapshot.iterator(prefix=bytes(hash.Data), include_key=False) as it: + for val in it: + if len(val) > 4: + try: + event = SmartContractEvent.FromByteArray(val) + results.append(event) + except Exception as e: + logger.error("could not parse event: %s %s" % (e, val)) return results def get_tokens(self): @@ -294,9 +297,10 @@ def get_tokens(self): """ tokens_snapshot = self.db.prefixed_db(NotificationPrefix.PREFIX_TOKEN).snapshot() results = [] - for val in tokens_snapshot.iterator(include_key=False): - event = SmartContractEvent.FromByteArray(val) - results.append(event) + with tokens_snapshot.iterator(include_key=False) as it: + for val in it: + event = SmartContractEvent.FromByteArray(val) + results.append(event) return results def get_token(self, hash): diff --git a/neo/Implementations/Wallets/peewee/UserWallet.py b/neo/Implementations/Wallets/peewee/UserWallet.py index 93a77caf7..291f2d4c7 100755 --- a/neo/Implementations/Wallets/peewee/UserWallet.py +++ b/neo/Implementations/Wallets/peewee/UserWallet.py @@ -31,12 +31,6 @@ class UserWallet(Wallet): Version = None - __dbaccount = None - - _aliases = None - - _db = None - def __init__(self, path, passwordKey, create): super(UserWallet, self).__init__(path, passwordKey=passwordKey, create=create) @@ -76,6 +70,8 @@ def Close(self): self._db.close() self._db = None + Blockchain.Default().PersistCompleted.on_change -= self.ProcessNewBlock + @staticmethod def Open(path, password): return UserWallet(path=path, passwordKey=password, create=False) @@ -573,7 +569,6 @@ def ToJson(self, verbose=False): addresses = [] has_watch_addr = False for addr in Address.select(): - logger.info("Script hash %s %s" % (addr.ScriptHash, type(addr.ScriptHash))) addr_str = Crypto.ToAddress(UInt160(data=addr.ScriptHash)) acct = Blockchain.Default().GetAccountState(addr_str) token_balances = self.TokenBalancesForAddress(addr_str) diff --git a/neo/Implementations/Wallets/peewee/test_user_wallet.py b/neo/Implementations/Wallets/peewee/test_user_wallet.py index 1a36ddbd5..b4111786b 100644 --- a/neo/Implementations/Wallets/peewee/test_user_wallet.py +++ b/neo/Implementations/Wallets/peewee/test_user_wallet.py @@ -11,8 +11,10 @@ from neo.Wallets.NEP5Token import NEP5Token from neo.SmartContract.ContractParameterContext import ContractParametersContext from neo.Core.TX.Transaction import ContractTransaction, TransactionOutput, TXFeeError -from neo.Network.NodeLeader import NodeLeader +from mock import patch import binascii +from neo.Network.nodemanager import NodeManager +from neo.Network.node import NeoNode class UserWalletTestCase(WalletFixtureTestCase): @@ -187,20 +189,23 @@ def test_8_named_addr(self): self.assertEqual(presult, self.wallet_1_script_hash.Data) def test_9_send_neo_tx(self): - - wallet = self.GetWallet1() - - tx = ContractTransaction() - tx.outputs = [TransactionOutput(Blockchain.SystemShare().Hash, Fixed8.FromDecimal(10.0), self.import_watch_addr)] - - try: - tx = wallet.MakeTransaction(tx) - except (ValueError, TXFeeError): - pass - - cpc = ContractParametersContext(tx) - wallet.Sign(cpc) - tx.scripts = cpc.GetScripts() - - result = NodeLeader.Instance().Relay(tx) - self.assertEqual(result, 
True) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + wallet = self.GetWallet1() + + tx = ContractTransaction() + tx.outputs = [TransactionOutput(Blockchain.SystemShare().Hash, Fixed8.FromDecimal(10.0), self.import_watch_addr)] + + try: + tx = wallet.MakeTransaction(tx) + except (ValueError, TXFeeError): + pass + + cpc = ContractParametersContext(tx) + wallet.Sign(cpc) + tx.scripts = cpc.GetScripts() + + nodemgr = NodeManager() + # we need at least 1 node for relay to be mocked + nodemgr.nodes = [NeoNode(object, object)] + result = nodemgr.relay(tx) + self.assertEqual(result, True) diff --git a/neo/Network/Message.py b/neo/Network/Message.py deleted file mode 100644 index 72122a354..000000000 --- a/neo/Network/Message.py +++ /dev/null @@ -1,110 +0,0 @@ -import binascii -from neo.Core.IO.Mixins import SerializableMixin -from neo.Settings import settings -from neo.Core.Helper import Helper -from neo.Core.Cryptography.Helper import bin_dbl_sha256 -from neo.Core.Size import Size as s -from neo.logging import log_manager - -logger = log_manager.getLogger() - - -class ChecksumException(Exception): - pass - - -class Message(SerializableMixin): - PayloadMaxSize = b'\x02000000' - PayloadMaxSizeInt = int.from_bytes(PayloadMaxSize, 'big') - - Magic = None - - Command = None - - Checksum = None - - Payload = None - - Length = 0 - - def __init__(self, command=None, payload=None, print_payload=False): - """ - Create an instance. - - Args: - command (str): payload command e.g. "inv", "getdata". See NeoNode.MessageReceived() for more commands. - payload (bytes): raw bytes of the payload. - print_payload: UNUSED - """ - self.Command = command - self.Magic = settings.MAGIC - - if payload is None: - payload = bytearray() - else: - payload = binascii.unhexlify(Helper.ToArray(payload)) - - self.Checksum = Message.GetChecksum(payload) - self.Payload = payload - - if print_payload: - logger.info("PAYLOAD: %s " % self.Payload) - - def Size(self): - """ - Get the total size in bytes of the object. - - Returns: - int: size. - """ - return s.uint32 + 12 + s.uint32 + s.uint32 + len(self.Payload) - - def Deserialize(self, reader): - """ - Deserialize full object. - - Args: - reader (neo.IO.BinaryReader): - """ - self.Magic = reader.ReadUInt32() - self.Command = reader.ReadFixedString(12).decode('utf-8') - self.Length = reader.ReadUInt32() - - if self.Length > self.PayloadMaxSizeInt: - raise Exception("invalid format- payload too large") - - self.Checksum = reader.ReadUInt32() - self.Payload = reader.ReadBytes(self.Length) - - checksum = Message.GetChecksum(self.Payload) - - if checksum != self.Checksum: - raise ChecksumException("checksum mismatch") - - @staticmethod - def GetChecksum(value): - """ - Get the double SHA256 hash of the value. - - Args: - value (obj): a payload - - Returns: - - """ - uint32 = bin_dbl_sha256(value)[:4] - - return int.from_bytes(uint32, 'little') - - def Serialize(self, writer): - """ - Serialize object. 
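The Message class removed above frames every payload with a checksum taken from the first four bytes of a double SHA256 digest, read as a little-endian integer (see the GetChecksum staticmethod in the deleted hunk). For reference, the same value can be computed with nothing but hashlib; this is an illustrative sketch, not part of the patch:

.. code-block:: python

    import hashlib

    def message_checksum(payload: bytes) -> int:
        # Double SHA256 of the payload, first four bytes interpreted little-endian,
        # mirroring Message.GetChecksum from the file deleted above.
        digest = hashlib.sha256(hashlib.sha256(payload).digest()).digest()
        return int.from_bytes(digest[:4], 'little')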
- - Args: - writer (neo.IO.BinaryWriter): - """ - writer.WriteUInt32(self.Magic) - writer.WriteFixedString(self.Command, 12) - writer.WriteUInt32(len(self.Payload)) - writer.WriteUInt32(self.Checksum) - writer.WriteBytes(self.Payload) diff --git a/neo/Network/Mixins.py b/neo/Network/Mixins.py index df320ce66..7f4b6e4aa 100644 --- a/neo/Network/Mixins.py +++ b/neo/Network/Mixins.py @@ -1,11 +1,11 @@ - from neo.Core.Mixins import VerifiableMixin class InventoryMixin(VerifiableMixin): - Hash = None - InventoryType = None + def __init__(self): + super(InventoryMixin, self).__init__() + self.InventoryType = None def Verify(self): pass diff --git a/neo/Network/NeoNode.py b/neo/Network/NeoNode.py deleted file mode 100644 index 6bbff60bb..000000000 --- a/neo/Network/NeoNode.py +++ /dev/null @@ -1,876 +0,0 @@ -import binascii -import random -import datetime -from twisted.internet.protocol import Protocol -from twisted.internet import error as twisted_error -from twisted.internet import reactor, task, defer -from twisted.internet.address import IPv4Address -from twisted.internet.defer import CancelledError -from twisted.internet import error -from neo.Core.Blockchain import Blockchain as BC -from neo.Core.IO.BinaryReader import BinaryReader -from neo.Network.Message import Message -from neo.IO.MemoryStream import StreamManager -from neo.IO.Helper import Helper as IOHelper -from neo.Core.Helper import Helper -from .Payloads.GetBlocksPayload import GetBlocksPayload -from .Payloads.InvPayload import InvPayload -from .Payloads.NetworkAddressWithTime import NetworkAddressWithTime -from .Payloads.VersionPayload import VersionPayload -from .Payloads.HeadersPayload import HeadersPayload -from .Payloads.AddrPayload import AddrPayload -from .InventoryType import InventoryType -from neo.Settings import settings -from neo.logging import log_manager -from neo.Network.address import Address - -logger = log_manager.getLogger('network') -logger_verbose = log_manager.getLogger('network.verbose') -MODE_MAINTAIN = 7 -MODE_CATCHUP = 2 - -mode_to_name = {MODE_CATCHUP: 'CATCHUP', MODE_MAINTAIN: 'MAINTAIN'} - -HEARTBEAT_BLOCKS = 'B' -HEARTBEAT_HEADERS = 'H' - - -class NeoNode(Protocol): - Version = None - - leader = None - - identifier = None - - def has_tasks_running(self): - block = False - header = False - peer = False - if self.block_loop and self.block_loop.running: - block = True - - if self.peer_loop and self.peer_loop.running: - peer = True - - if self.header_loop and self.header_loop.running: - header = True - - return block and header and peer - - def start_all_tasks(self): - if not self.disconnecting: - self.start_block_loop() - self.start_header_loop() - self.start_peerinfo_loop() - - def start_block_loop(self): - logger_verbose.debug(f"{self.prefix} start_block_loop") - if self.block_loop and self.block_loop.running: - logger_verbose.debug(f"start_block_loop: still running -> stopping...") - self.stop_block_loop() - self.block_loop = task.LoopingCall(self.AskForMoreBlocks) - self.block_loop_deferred = self.block_loop.start(self.sync_mode, now=False) - self.block_loop_deferred.addErrback(self.OnLoopError) - # self.leader.task_handles[self.block_loop] = self.prefix + f"{'block_loop':>15}" - - def stop_block_loop(self, cancel=True): - logger_verbose.debug(f"{self.prefix} stop_block_loop: cancel -> {cancel}") - if self.block_loop: - logger_verbose.debug(f"{self.prefix} self.block_loop true") - if self.block_loop.running: - logger_verbose.debug(f"{self.prefix} stop_block_loop, calling stop") - 
self.block_loop.stop() - if cancel and self.block_loop_deferred: - logger_verbose.debug(f"{self.prefix} stop_block_loop: trying to cancel") - self.block_loop_deferred.cancel() - - def start_peerinfo_loop(self): - logger_verbose.debug(f"{self.prefix} start_peerinfo_loop") - if self.peer_loop and self.peer_loop.running: - logger_verbose.debug(f"start_peer_loop: still running -> stopping...") - self.stop_peerinfo_loop() - self.peer_loop = task.LoopingCall(self.RequestPeerInfo) - self.peer_loop_deferred = self.peer_loop.start(120, now=False) - self.peer_loop_deferred.addErrback(self.OnLoopError) - # self.leader.task_handles[self.peer_loop] = self.prefix + f"{'peerinfo_loop':>15}" - - def stop_peerinfo_loop(self, cancel=True): - logger_verbose.debug(f"{self.prefix} stop_peerinfo_loop: cancel -> {cancel}") - if self.peer_loop and self.peer_loop.running: - logger_verbose.debug(f"{self.prefix} stop_peerinfo_loop, calling stop") - self.peer_loop.stop() - if cancel and self.peer_loop_deferred: - logger_verbose.debug(f"{self.prefix} stop_peerinfo_loop: trying to cancel") - self.peer_loop_deferred.cancel() - - def start_header_loop(self): - logger_verbose.debug(f"{self.prefix} start_header_loop") - if self.header_loop and self.header_loop.running: - logger_verbose.debug(f"start_header_loop: still running -> stopping...") - self.stop_header_loop() - self.header_loop = task.LoopingCall(self.AskForMoreHeaders) - self.header_loop_deferred = self.header_loop.start(5, now=False) - self.header_loop_deferred.addErrback(self.OnLoopError) - # self.leader.task_handles[self.header_loop] = self.prefix + f"{'header_loop':>15}" - - def stop_header_loop(self, cancel=True): - logger_verbose.debug(f"{self.prefix} stop_header_loop: cancel -> {cancel}") - if self.header_loop: - logger_verbose.debug(f"{self.prefix} self.header_loop true") - if self.header_loop.running: - logger_verbose.debug(f"{self.prefix} stop_header_loop, calling stop") - self.header_loop.stop() - if cancel and self.header_loop_deferred: - logger_verbose.debug(f"{self.prefix} stop_header_loop: trying to cancel") - self.header_loop_deferred.cancel() - - def __init__(self, incoming_client=False): - """ - Create an instance. - The NeoNode class is the equivalent of the C# RemoteNode.cs class. It represents a single Node connected to the client. - - Args: - incoming_client (bool): True if node is an incoming client and the handshake should be initiated. 
- """ - from neo.Network.NodeLeader import NodeLeader - - self.leader = NodeLeader.Instance() - self.nodeid = self.leader.NodeId - self.remote_nodeid = random.randint(1294967200, 4294967200) - self.endpoint = '' - self.address = None - self.buffer_in = bytearray() - self.myblockrequests = set() - self.bytes_in = 0 - self.bytes_out = 0 - - self.sync_mode = MODE_CATCHUP - - self.host = None - self.port = None - - self.incoming_client = incoming_client - self.handshake_complete = False - self.expect_verack_next = False - self.start_outstanding_data_request = {HEARTBEAT_BLOCKS: 0, HEARTBEAT_HEADERS: 0} - - self.block_loop = None - self.block_loop_deferred = None - - self.peer_loop = None - self.peer_loop_deferred = None - - self.header_loop = None - self.header_loop_deferred = None - - self.disconnect_deferred = None - self.disconnecting = False - - logger.debug(f"{self.prefix} new node created, not yet connected") - - def Disconnect(self, reason=None, isDead=True): - """Close the connection with the remote node client.""" - self.disconnecting = True - self.expect_verack_next = False - if reason: - logger.debug(f"Disconnecting with reason: {reason}") - self.stop_block_loop() - self.stop_header_loop() - self.stop_peerinfo_loop() - if isDead: - self.leader.AddDeadAddress(self.address, reason=f"{self.prefix} Forced disconnect by us") - - self.leader.forced_disconnect_by_us += 1 - - self.disconnect_deferred = defer.Deferred() - self.disconnect_deferred.debug = True - # force disconnection without waiting on the other side - # calling later to give func caller time to add callbacks to the deferred - reactor.callLater(1, self.transport.abortConnection) - return self.disconnect_deferred - - @property - def prefix(self): - if isinstance(self.endpoint, IPv4Address) and self.identifier is not None: - return f"[{self.identifier:03}][{mode_to_name[self.sync_mode]}][{self.address:>21}]" - else: - return f"" - - def Name(self): - """ - Get the peer name. - - Returns: - str: - """ - name = "" - if self.Version: - name = self.Version.UserAgent - return name - - def GetNetworkAddressWithTime(self): - """ - Get a network address object. - - Returns: - NetworkAddressWithTime: if we have a connection to a node. - None: otherwise. - """ - if self.port is not None and self.host is not None and self.Version is not None: - return NetworkAddressWithTime(self.host, self.port, self.Version.Services) - return None - - def IOStats(self): - """ - Get the connection I/O stats. - - Returns: - str: - """ - biM = self.bytes_in / 1000000 # megabyes - boM = self.bytes_out / 1000000 - - return f"{biM:>10} MB in / {boM:>10} MB out" - - def connectionMade(self): - """Callback handler from twisted when establishing a new connection.""" - self.endpoint = self.transport.getPeer() - # get the reference to the Address object in NodeLeader so we can manipulate it properly. - tmp_addr = Address(f"{self.endpoint.host}:{self.endpoint.port}") - try: - known_idx = self.leader.KNOWN_ADDRS.index(tmp_addr) - self.address = self.leader.KNOWN_ADDRS[known_idx] - except ValueError: - # Not found. 
- self.leader.AddKnownAddress(tmp_addr) - self.address = tmp_addr - - self.address.address = "%s:%s" % (self.endpoint.host, self.endpoint.port) - self.host = self.endpoint.host - self.port = int(self.endpoint.port) - self.leader.AddConnectedPeer(self) - self.leader.RemoveFromQueue(self.address) - self.leader.peers_connecting -= 1 - logger.debug(f"{self.address} connection established") - if self.incoming_client: - # start protocol - self.SendVersion() - - def connectionLost(self, reason=None): - """Callback handler from twisted when a connection was lost.""" - try: - self.connected = False - self.stop_block_loop() - self.stop_peerinfo_loop() - self.stop_header_loop() - - self.ReleaseBlockRequests() - self.leader.RemoveConnectedPeer(self) - - time_expired = self.time_expired(HEARTBEAT_BLOCKS) - # some NEO-cli versions have a 30s timeout to receive block/consensus or tx messages. By default neo-python doesn't respond to these requests - if time_expired > 20: - self.address.last_connection = Address.Now() - self.leader.AddDeadAddress(self.address, reason=f"{self.prefix} Premature disconnect") - - if reason and reason.check(twisted_error.ConnectionDone): - # this might happen if they close our connection because they've reached max peers or something similar - logger.debug(f"{self.prefix} disconnected normally with reason:{reason.value}") - self._check_for_consecutive_disconnects("connection done") - - elif reason and reason.check(twisted_error.ConnectionLost): - # Can be due to a timeout. Only if this happened again within 5 minutes do we label the node as bad - # because then it clearly doesn't want to talk to us or we have a bad connection to them. - # Otherwise allow for the node to be queued again by NodeLeader. - logger.debug(f"{self.prefix} disconnected with connectionlost reason: {reason.value}") - self._check_for_consecutive_disconnects("connection lost") - - else: - logger.debug(f"{self.prefix} disconnected with reason: {reason.value}") - except Exception as e: - logger.error("Error with connection lost: %s " % e) - - def try_me(err): - err.check(error.ConnectionAborted) - - if self.disconnect_deferred: - d, self.disconnect_deferred = self.disconnect_deferred, None # type: defer.Deferred - d.addErrback(try_me) - if len(d.callbacks) > 0: - d.callback(reason) - else: - print("connLost, disconnect_deferred cancelling!") - d.cancel() - - def _check_for_consecutive_disconnects(self, error_name): - now = datetime.datetime.utcnow().timestamp() - FIVE_MINUTES = 5 * 60 - if self.address.last_connection != 0 and now - self.address.last_connection < FIVE_MINUTES: - self.leader.AddDeadAddress(self.address, reason=f"{self.prefix} second {error_name} within 5 minutes") - else: - self.address.last_connection = Address.Now() - - def ReleaseBlockRequests(self): - bcr = BC.Default().BlockRequests - requests = self.myblockrequests - - for req in requests: - try: - if req in bcr: - bcr.remove(req) - except Exception as e: - logger.debug(f"{self.prefix} Could not remove request {e}") - - self.myblockrequests = set() - - def dataReceived(self, data): - """ Called from Twisted whenever data is received. 
""" - self.bytes_in += (len(data)) - self.buffer_in = self.buffer_in + data - - while self.CheckDataReceived(): - pass - - def CheckDataReceived(self): - """Tries to extract a Message from the data buffer and process it.""" - currentLength = len(self.buffer_in) - if currentLength < 24: - return False - # Extract the message header from the buffer, and return if not enough - # buffer to fully deserialize the message object. - - try: - # Construct message - mstart = self.buffer_in[:24] - ms = StreamManager.GetStream(mstart) - reader = BinaryReader(ms) - m = Message() - - # Extract message metadata - m.Magic = reader.ReadUInt32() - m.Command = reader.ReadFixedString(12).decode('utf-8') - m.Length = reader.ReadUInt32() - m.Checksum = reader.ReadUInt32() - - # Return if not enough buffer to fully deserialize object. - messageExpectedLength = 24 + m.Length - if currentLength < messageExpectedLength: - return False - - except Exception as e: - logger.debug(f"{self.prefix} Error: could not read message header from stream {e}") - # self.Log('Error: Could not read initial bytes %s ' % e) - return False - - finally: - StreamManager.ReleaseStream(ms) - del reader - - # The message header was successfully extracted, and we have enough enough buffer - # to extract the full payload - try: - # Extract message bytes from buffer and truncate buffer - mdata = self.buffer_in[:messageExpectedLength] - self.buffer_in = self.buffer_in[messageExpectedLength:] - - # Deserialize message with payload - stream = StreamManager.GetStream(mdata) - reader = BinaryReader(stream) - message = Message() - message.Deserialize(reader) - - if self.incoming_client and self.expect_verack_next: - if message.Command != 'verack': - self.Disconnect("Expected 'verack' got {}".format(message.Command)) - - # Propagate new message - self.MessageReceived(message) - - except Exception as e: - logger.debug(f"{self.prefix} Could not extract message {e}") - # self.Log('Error: Could not extract message: %s ' % e) - return False - - finally: - StreamManager.ReleaseStream(stream) - - return True - - def MessageReceived(self, m): - """ - Process a message. 
- - Args: - m (neo.Network.Message): - """ - if m.Command == 'verack': - # only respond with a verack when we connect to another client, not when a client connected to us or - # we might end up in a verack loop - if self.incoming_client: - if self.expect_verack_next: - self.expect_verack_next = False - else: - self.HandleVerack() - elif m.Command == 'version': - self.HandleVersion(m.Payload) - elif m.Command == 'getaddr': - self.SendPeerInfo() - elif m.Command == 'getdata': - self.HandleGetDataMessageReceived(m.Payload) - elif m.Command == 'getblocks': - self.HandleGetBlocksMessageReceived(m.Payload) - elif m.Command == 'inv': - self.HandleInvMessage(m.Payload) - elif m.Command == 'block': - self.HandleBlockReceived(m.Payload) - elif m.Command == 'getheaders': - self.HandleGetHeadersMessageReceived(m.Payload) - elif m.Command == 'headers': - self.HandleBlockHeadersReceived(m.Payload) - elif m.Command == 'addr': - self.HandlePeerInfoReceived(m.Payload) - else: - logger.debug(f"{self.prefix} Command not implemented: {m.Command}") - - def OnLoopError(self, err): - # happens if we cancel the disconnect_deferred before it is executed - # causes no harm - if type(err.value) == CancelledError: - logger_verbose.debug(f"{self.prefix} OnLoopError cancelled deferred") - return - logger.debug(f"{self.prefix} On neo Node loop error {err}") - - def onThreadDeferredErr(self, err): - if type(err.value) == CancelledError: - logger_verbose.debug(f"{self.prefix} onThreadDeferredError cancelled deferred") - return - logger.debug(f"{self.prefix} On Call from thread error {err}") - - def keep_alive(self): - ka = Message("ping") - self.SendSerializedMessage(ka) - - def ProtocolReady(self): - # do not start the looping tasks if we're in the BlockRequests catchup task - # otherwise BCRLen will not drop because the new node will continue adding blocks - logger_verbose.debug(f"{self.prefix} ProtocolReady called") - if not self.leader.check_bcr_loop or (self.leader.check_bcr_loop and not self.leader.check_bcr_loop.running): - logger_verbose.debug(f"{self.prefix} Protocol ready -> starting loops") - self.start_block_loop() - self.start_peerinfo_loop() - self.start_header_loop() - - self.RequestPeerInfo() - - def AskForMoreHeaders(self): - logger.debug(f"{self.prefix} asking for more headers, starting from {BC.Default().HeaderHeight}") - self.health_check(HEARTBEAT_HEADERS) - get_headers_message = Message("getheaders", GetBlocksPayload(hash_start=[BC.Default().CurrentHeaderHash])) - self.SendSerializedMessage(get_headers_message) - - def AskForMoreBlocks(self): - - distance = BC.Default().HeaderHeight - BC.Default().Height - - current_mode = self.sync_mode - - if distance > 2000: - self.sync_mode = MODE_CATCHUP - else: - self.sync_mode = MODE_MAINTAIN - - if self.sync_mode != current_mode: - logger.debug(f"{self.prefix} changing sync_mode to {mode_to_name[self.sync_mode]}") - self.stop_block_loop() - self.start_block_loop() - - else: - if len(BC.Default().BlockRequests) > self.leader.BREQMAX: - logger.debug(f"{self.prefix} data request speed exceeding node response rate...pausing to catch up") - self.leader.throttle_sync() - else: - self.DoAskForMoreBlocks() - - def DoAskForMoreBlocks(self): - hashes = [] - hashstart = BC.Default().Height + 1 - current_header_height = BC.Default().HeaderHeight + 1 - - do_go_ahead = False - if BC.Default().BlockSearchTries > 100 and len(BC.Default().BlockRequests) > 0: - do_go_ahead = True - - first = None - while hashstart <= current_header_height and len(hashes) < 
self.leader.BREQPART: - hash = BC.Default().GetHeaderHash(hashstart) - if not do_go_ahead: - if hash is not None and hash not in BC.Default().BlockRequests \ - and hash not in self.myblockrequests: - - if not first: - first = hashstart - BC.Default().BlockRequests.add(hash) - self.myblockrequests.add(hash) - hashes.append(hash) - else: - if hash is not None: - if not first: - first = hashstart - BC.Default().BlockRequests.add(hash) - self.myblockrequests.add(hash) - hashes.append(hash) - - hashstart += 1 - - if len(hashes) > 0: - logger.debug( - f"{self.prefix} asking for more blocks {first} - {hashstart} ({len(hashes)}) stale count: {BC.Default().BlockSearchTries} " - f"BCRLen: {len(BC.Default().BlockRequests)}") - self.health_check(HEARTBEAT_BLOCKS) - message = Message("getdata", InvPayload(InventoryType.Block, hashes)) - self.SendSerializedMessage(message) - - def RequestPeerInfo(self): - """Request the peer address information from the remote client.""" - logger.debug(f"{self.prefix} requesting peer info") - self.SendSerializedMessage(Message('getaddr')) - - def HandlePeerInfoReceived(self, payload): - """Process response of `self.RequestPeerInfo`.""" - addrs = IOHelper.AsSerializableWithType(payload, 'neo.Network.Payloads.AddrPayload.AddrPayload') - - if not addrs: - return - - for nawt in addrs.NetworkAddressesWithTime: - self.leader.RemoteNodePeerReceived(nawt.Address, nawt.Port, self.prefix) - - def SendPeerInfo(self): - # if not self.leader.ServiceEnabled: - # return - - peerlist = [] - for peer in self.leader.Peers: - addr = peer.GetNetworkAddressWithTime() - if addr is not None: - peerlist.append(addr) - peer_str_list = list(map(lambda p: p.ToString(), peerlist)) - logger.debug(f"{self.prefix} Sending Peer list {peer_str_list}") - - addrpayload = AddrPayload(addresses=peerlist) - message = Message('addr', addrpayload) - self.SendSerializedMessage(message) - - def RequestVersion(self): - """Request the remote client version.""" - m = Message("getversion") - self.SendSerializedMessage(m) - - def SendVersion(self): - """Send our client version.""" - m = Message("version", VersionPayload(settings.NODE_PORT, self.remote_nodeid, settings.VERSION_NAME)) - self.SendSerializedMessage(m) - - def SendVerack(self): - """Send version acknowledge""" - m = Message('verack') - self.SendSerializedMessage(m) - self.expect_verack_next = True - - def HandleVersion(self, payload): - """Process the response of `self.RequestVersion`.""" - self.Version = IOHelper.AsSerializableWithType(payload, "neo.Network.Payloads.VersionPayload.VersionPayload") - - if not self.Version: - return - - if self.incoming_client: - if self.Version.Nonce == self.nodeid: - self.Disconnect() - self.SendVerack() - else: - self.nodeid = self.Version.Nonce - self.SendVersion() - - def HandleVerack(self): - """Handle the `verack` response.""" - m = Message('verack') - self.SendSerializedMessage(m) - self.leader.NodeCount += 1 - self.identifier = self.leader.NodeCount - logger.debug(f"{self.prefix} Handshake complete!") - self.handshake_complete = True - self.ProtocolReady() - - def HandleInvMessage(self, payload): - """ - Process a block header inventory payload. 
- - Args: - inventory (neo.Network.Payloads.InvPayload): - """ - - if self.sync_mode != MODE_MAINTAIN: - return - - inventory = IOHelper.AsSerializableWithType(payload, 'neo.Network.Payloads.InvPayload.InvPayload') - if not inventory: - return - - if inventory.Type == InventoryType.BlockInt: - - ok_hashes = [] - for hash in inventory.Hashes: - hash = hash.encode('utf-8') - if hash not in self.myblockrequests and hash not in BC.Default().BlockRequests: - ok_hashes.append(hash) - BC.Default().BlockRequests.add(hash) - self.myblockrequests.add(hash) - if len(ok_hashes): - message = Message("getdata", InvPayload(InventoryType.Block, ok_hashes)) - self.SendSerializedMessage(message) - - elif inventory.Type == InventoryType.TXInt: - pass - elif inventory.Type == InventoryType.ConsensusInt: - pass - - def SendSerializedMessage(self, message): - """ - Send the `message` to the remote client. - - Args: - message (neo.Network.Message): - """ - try: - ba = Helper.ToArray(message) - ba2 = binascii.unhexlify(ba) - self.bytes_out += len(ba2) - self.transport.write(ba2) - except Exception as e: - logger.debug(f"Could not send serialized message {e}") - - def HandleBlockHeadersReceived(self, inventory): - """ - Process a block header inventory payload. - - Args: - inventory (neo.Network.Inventory): - """ - try: - inventory = IOHelper.AsSerializableWithType(inventory, 'neo.Network.Payloads.HeadersPayload.HeadersPayload') - if inventory is not None: - logger.debug(f"{self.prefix} received headers") - self.heart_beat(HEARTBEAT_HEADERS) - BC.Default().AddHeaders(inventory.Headers) - - except Exception as e: - logger.debug(f"Error handling Block headers {e}") - - def HandleBlockReceived(self, inventory): - """ - Process a Block inventory payload. - - Args: - inventory (neo.Network.Inventory): - """ - block = IOHelper.AsSerializableWithType(inventory, 'neo.Core.Block.Block') - if not block: - return - - blockhash = block.Hash.ToBytes() - try: - if blockhash in BC.Default().BlockRequests: - BC.Default().BlockRequests.remove(blockhash) - except KeyError: - pass - try: - if blockhash in self.myblockrequests: - # logger.debug(f"{self.prefix} received block: {block.Index}") - self.heart_beat(HEARTBEAT_BLOCKS) - self.myblockrequests.remove(blockhash) - except KeyError: - pass - self.leader.InventoryReceived(block) - - def time_expired(self, what): - now = datetime.datetime.utcnow().timestamp() - start_time = self.start_outstanding_data_request.get(what) - if start_time == 0: - delta = 0 - else: - delta = now - start_time - return delta - - def health_check(self, what): - # now = datetime.datetime.utcnow().timestamp() - # delta = now - self.start_outstanding_data_request.get(what) - - time_expired = self.time_expired(what) - - if time_expired == 0: - # startup scenario, just go - logger.debug(f"{self.prefix}[HEALTH][{what}] startup or bcr catchup heart_beat") - self.heart_beat(what) - else: - if self.sync_mode == MODE_CATCHUP: - response_threshold = 45 # seconds - else: - response_threshold = 90 # - if time_expired > response_threshold: - header_time = self.time_expired(HEARTBEAT_HEADERS) - header_bad = header_time > response_threshold - block_time = self.time_expired(HEARTBEAT_BLOCKS) - blocks_bad = block_time > response_threshold - if header_bad and blocks_bad: - logger.debug( - f"{self.prefix}[HEALTH] FAILED - No response for Headers {header_time:.2f} and Blocks {block_time:.2f} seconds. 
Removing node...") - self.Disconnect() - elif blocks_bad and self.leader.check_bcr_loop and self.leader.check_bcr_loop.running: - # when we're in data throttling it is never acceptable if blocks don't come in. - logger.debug( - f"{self.prefix}[HEALTH] FAILED - No Blocks for {block_time:.2f} seconds while throttling. Removing node...") - self.Disconnect() - else: - if header_bad: - logger.debug( - f"{self.prefix}[HEALTH] Headers FAILED @ {header_time:.2f}s, but Blocks OK @ {block_time:.2f}s. Keeping node...") - else: - logger.debug( - f"{self.prefix}[HEALTH] Headers OK @ {header_time:.2f}s, but Blocks FAILED @ {block_time:.2f}s. Keeping node...") - - # logger.debug( - # f"{self.prefix}[HEALTH][{what}] FAILED - No response for {time_expired:.2f} seconds. Removing node...") - - else: - logger.debug(f"{self.prefix}[HEALTH][{what}] OK - response time {time_expired:.2f}") - - def heart_beat(self, what): - self.start_outstanding_data_request[what] = datetime.datetime.utcnow().timestamp() - - def HandleGetHeadersMessageReceived(self, payload): - - if not self.leader.ServiceEnabled: - return - - inventory = IOHelper.AsSerializableWithType(payload, 'neo.Network.Payloads.GetBlocksPayload.GetBlocksPayload') - - if not inventory: - return - - blockchain = BC.Default() - - hash = inventory.HashStart[0] - - if hash is None or hash == inventory.HashStop: - logger.debug("getheaders: Hash {} not found or hashstop reached".format(inventory.HashStart)) - return - - headers = [] - header_count = 0 - - while hash != inventory.HashStop and header_count < 2000: - hash = blockchain.GetNextBlockHash(hash) - if not hash: - break - headers.append(blockchain.GetHeader(hash)) - header_count += 1 - - if header_count > 0: - self.SendSerializedMessage(Message('headers', HeadersPayload(headers=headers))) - - def HandleBlockReset(self, hash): - """Process block reset request.""" - self.myblockrequests = set() - - def HandleGetDataMessageReceived(self, payload): - """ - Process a InvPayload payload. - - Args: - payload (neo.Network.Inventory): - """ - inventory = IOHelper.AsSerializableWithType(payload, 'neo.Network.Payloads.InvPayload.InvPayload') - if not inventory: - return - - for hash in inventory.Hashes: - hash = hash.encode('utf-8') - - item = None - # try to get the inventory to send from relay cache - - if hash in self.leader.RelayCache.keys(): - item = self.leader.RelayCache[hash] - - if inventory.Type == InventoryType.TXInt: - if not item: - item, index = BC.Default().GetTransaction(hash) - if not item: - item = self.leader.GetTransaction(hash) - if item: - message = Message(command='tx', payload=item, print_payload=False) - self.SendSerializedMessage(message) - - elif inventory.Type == InventoryType.BlockInt: - if not item: - item = BC.Default().GetBlock(hash) - if item: - message = Message(command='block', payload=item, print_payload=False) - self.SendSerializedMessage(message) - - elif inventory.Type == InventoryType.ConsensusInt: - if item: - self.SendSerializedMessage(Message(command='consensus', payload=item, print_payload=False)) - - def HandleGetBlocksMessageReceived(self, payload): - """ - Process a GetBlocksPayload payload. 
- - Args: - payload (neo.Network.Payloads.GetBlocksPayload): - """ - if not self.leader.ServiceEnabled: - return - - inventory = IOHelper.AsSerializableWithType(payload, 'neo.Network.Payloads.GetBlocksPayload.GetBlocksPayload') - if not inventory: - return - - blockchain = BC.Default() - hash = inventory.HashStart[0] - if not blockchain.GetHeader(hash): - return - - hashes = [] - hcount = 0 - while hash != inventory.HashStop and hcount < 500: - hash = blockchain.GetNextBlockHash(hash) - if hash is None: - break - hashes.append(hash) - hcount += 1 - if hcount > 0: - self.SendSerializedMessage(Message('inv', InvPayload(type=InventoryType.Block, hashes=hashes))) - - def Relay(self, inventory): - """ - Wrap the inventory in a InvPayload object and send it over the write to the remote node. - - Args: - inventory: - - Returns: - bool: True (fixed) - """ - inventory = InvPayload(type=inventory.InventoryType, hashes=[inventory.Hash.ToBytes()]) - m = Message("inv", inventory) - self.SendSerializedMessage(m) - - return True - - def __eq__(self, other): - if type(other) is type(self): - return self.address == other.address and self.identifier == other.identifier - else: - return False diff --git a/neo/Network/NodeLeader.py b/neo/Network/NodeLeader.py deleted file mode 100644 index 2332ca0a7..000000000 --- a/neo/Network/NodeLeader.py +++ /dev/null @@ -1,777 +0,0 @@ -import random -import time -from typing import List -from neo.Core.Block import Block -from neo.Core.Blockchain import Blockchain as BC -from neo.Implementations.Blockchains.LevelDB.TestLevelDBBlockchain import TestLevelDBBlockchain -from neo.Core.TX.Transaction import Transaction -from neo.Core.TX.MinerTransaction import MinerTransaction -from neo.Network.NeoNode import NeoNode, HEARTBEAT_BLOCKS -from neo.Settings import settings -from twisted.internet.protocol import ReconnectingClientFactory, Factory -from twisted.internet import error -from twisted.internet import task -from twisted.internet import reactor as twisted_reactor -from twisted.internet.defer import CancelledError, Deferred -from twisted.internet.endpoints import TCP4ClientEndpoint, connectProtocol, TCP4ServerEndpoint -from neo.logging import log_manager -from neo.Network.address import Address -from neo.Network.Utils import LoopingCall, hostname_to_ip, is_ip_address - -logger = log_manager.getLogger('network') - - -class NodeLeader: - _LEAD = None - - Peers = [] - - KNOWN_ADDRS = [] - DEAD_ADDRS = [] - - NodeId = None - - _MissedBlocks = [] - - BREQPART = 100 - BREQMAX = 10000 - - KnownHashes = [] - MissionsGlobal = [] - MemPool = {} - RelayCache = {} - - NodeCount = 0 - - CurrentBlockheight = 0 - - ServiceEnabled = False - - peer_check_loop = None - peer_check_loop_deferred = None - - check_bcr_loop = None - check_bcr_loop_deferred = None - - memcheck_loop = None - memcheck_loop_deferred = None - - blockheight_loop = None - blockheight_loop_deferred = None - - task_handles = {} - - @staticmethod - def Instance(reactor=None): - """ - Get the local node instance. - - Args: - reactor: (optional) custom reactor to use in NodeLeader. - - Returns: - NodeLeader: instance. - """ - if NodeLeader._LEAD is None: - NodeLeader._LEAD = NodeLeader(reactor) - return NodeLeader._LEAD - - def __init__(self, reactor=None): - """ - Create an instance. - This is the equivalent to C#'s LocalNode.cs - """ - self.Setup() - self.ServiceEnabled = settings.SERVICE_ENABLED - self.peer_zero_count = 0 # track the number of times PeerCheckLoop saw a Peer count of zero. Reset e.g. 
after 3 times - self.connection_queue = [] - self.reactor = twisted_reactor - self.incoming_server_running = False - self.forced_disconnect_by_us = 0 - self.peers_connecting = 0 - - # for testability - if reactor: - self.reactor = reactor - - def start_peer_check_loop(self): - logger.debug(f"start_peer_check_loop") - if self.peer_check_loop and self.peer_check_loop.running: - logger.debug("start_peer_check_loop: still running -> stopping...") - self.stop_peer_check_loop() - - self.peer_check_loop = LoopingCall(self.PeerCheckLoop, clock=self.reactor) - self.peer_check_loop_deferred = self.peer_check_loop.start(10, now=False) - self.peer_check_loop_deferred.addErrback(self.OnPeerLoopError) - - def stop_peer_check_loop(self, cancel=True): - logger.debug(f"stop_peer_check_loop, cancel: {cancel}") - if self.peer_check_loop and self.peer_check_loop.running: - logger.debug(f"stop_peer_check_loop, calling stop()") - self.peer_check_loop.stop() - if cancel and self.peer_check_loop_deferred: - logger.debug(f"stop_peer_check_loop, calling cancel()") - self.peer_check_loop_deferred.cancel() - - def start_check_bcr_loop(self): - logger.debug(f"start_check_bcr_loop") - if self.check_bcr_loop and self.check_bcr_loop.running: - logger.debug("start_check_bcr_loop: still running -> stopping...") - self.stop_check_bcr_loop() - - self.check_bcr_loop = LoopingCall(self.check_bcr_catchup, clock=self.reactor) - self.check_bcr_loop_deferred = self.check_bcr_loop.start(5) - self.check_bcr_loop_deferred.addErrback(self.OnCheckBcrError) - - def stop_check_bcr_loop(self, cancel=True): - logger.debug(f"stop_check_bcr_loop, cancel: {cancel}") - if self.check_bcr_loop and self.check_bcr_loop.running: - logger.debug(f"stop_check_bcr_loop, calling stop()") - self.check_bcr_loop.stop() - if cancel and self.check_bcr_loop_deferred: - logger.debug(f"stop_check_bcr_loop, calling cancel()") - self.check_bcr_loop_deferred.cancel() - - def start_memcheck_loop(self): - self.stop_memcheck_loop() - self.memcheck_loop = LoopingCall(self.MempoolCheck, clock=self.reactor) - self.memcheck_loop_deferred = self.memcheck_loop.start(240, now=False) - self.memcheck_loop_deferred.addErrback(self.OnMemcheckError) - - def stop_memcheck_loop(self, cancel=True): - if self.memcheck_loop and self.memcheck_loop.running: - self.memcheck_loop.stop() - if cancel and self.memcheck_loop_deferred: - self.memcheck_loop_deferred.cancel() - - def start_blockheight_loop(self): - self.stop_blockheight_loop() - self.CurrentBlockheight = BC.Default().Height - self.blockheight_loop = LoopingCall(self.BlockheightCheck, clock=self.reactor) - self.blockheight_loop_deferred = self.blockheight_loop.start(240, now=False) - self.blockheight_loop_deferred.addErrback(self.OnBlockheightcheckError) - - def stop_blockheight_loop(self, cancel=True): - if self.blockheight_loop and self.blockheight_loop.running: - self.blockheight_loop.stop() - if cancel and self.blockheight_loop_deferred: - self.blockheight_loop_deferred.cancel() - - def Setup(self): - """ - Initialize the local node. 
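The deleted NodeLeader schedules all of its recurring work (peer checks, block-request catch-up, mempool and block-height checks) through Twisted LoopingCall objects paired with deferreds for cancellation. Under the asyncio framework this changeset moves to, that pattern is usually expressed as a task wrapping a sleep loop; the snippet below is only a generic sketch of the idiom, not the replacement nodemanager code (which is not part of this hunk):

.. code-block:: python

    import asyncio

    async def run_periodically(interval: float, func) -> None:
        # Rough asyncio counterpart of task.LoopingCall(func).start(interval, now=False)
        while True:
            await asyncio.sleep(interval)
            func()

    # Keep a reference to the task so it can be cancelled later, much like the
    # *_loop_deferred attributes the removed NodeLeader held on to, e.g.:
    # check_task = asyncio.ensure_future(run_periodically(10, some_check))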
- - Returns: - - """ - self.Peers = [] # active nodes that we're connected to - self.KNOWN_ADDRS = [] # node addresses that we've learned about from other nodes - self.DEAD_ADDRS = [] # addresses that were performing poorly or we could not establish a connection to - self.MissionsGlobal = [] - self.NodeId = random.randint(1294967200, 4294967200) - - def Restart(self): - self.stop_peer_check_loop() - self.stop_check_bcr_loop() - self.stop_memcheck_loop() - self.stop_blockheight_loop() - - self.peer_check_loop_deferred = None - self.check_bcr_loop_deferred = None - self.memcheck_loop_deferred = None - self.blockheight_loop_deferred = None - - self.peers_connecting = 0 - - if len(self.Peers) == 0: - # preserve any addresses we know because the peers in the seedlist might have gone bad and then we can't receive new addresses anymore - unique_addresses = list(set(self.KNOWN_ADDRS + self.DEAD_ADDRS)) - self.KNOWN_ADDRS = unique_addresses - self.DEAD_ADDRS = [] - self.peer_zero_count = 0 - self.connection_queue = [] - - self.Start(skip_seeds=True) - - def throttle_sync(self): - for peer in self.Peers: # type: NeoNode - peer.stop_block_loop(cancel=False) - peer.stop_peerinfo_loop(cancel=False) - peer.stop_header_loop(cancel=False) - - # start a loop to check if we've caught up on our requests - if not self.check_bcr_loop: - self.start_check_bcr_loop() - - def check_bcr_catchup(self): - """we're exceeding data request speed vs receive + process""" - logger.debug(f"Checking if BlockRequests has caught up {len(BC.Default().BlockRequests)}") - - # test, perhaps there's some race condition between slow startup and throttle sync, otherwise blocks will never go down - for peer in self.Peers: # type: NeoNode - peer.stop_block_loop(cancel=False) - peer.stop_peerinfo_loop(cancel=False) - peer.stop_header_loop(cancel=False) - - if len(BC.Default().BlockRequests) > 0: - for peer in self.Peers: - peer.keep_alive() - peer.health_check(HEARTBEAT_BLOCKS) - peer_bcr_len = len(peer.myblockrequests) - # if a peer has cleared its queue then reset heartbeat status to avoid timing out when resuming from "check_bcr" if there's 1 or more really slow peer(s) - if peer_bcr_len == 0: - peer.start_outstanding_data_request[HEARTBEAT_BLOCKS] = 0 - - print(f"{peer.prefix} request count: {peer_bcr_len}") - if peer_bcr_len == 1: - next_hash = BC.Default().GetHeaderHash(self.CurrentBlockheight + 1) - print(f"{peer.prefix} {peer.myblockrequests} {next_hash}") - else: - # we're done catching up. Stop own loop and restart peers - self.stop_check_bcr_loop() - self.check_bcr_loop = None - logger.debug("BlockRequests have caught up...resuming sync") - for peer in self.Peers: - peer.ProtocolReady() # this starts all loops again - # give a little bit of time between startup of peers - time.sleep(2) - - def _process_connection_queue(self): - for addr in self.connection_queue: - self.SetupConnection(addr) - - def Start(self, seed_list: List[str] = None, skip_seeds: bool = False) -> None: - """ - Start connecting to the seed list. 
- - Args: - seed_list: a list of host:port strings if not supplied use list from `protocol.xxx.json` - skip_seeds: skip connecting to seed list - """ - if not seed_list: - seed_list = settings.SEED_LIST - - logger.debug("Starting up nodeleader") - if not skip_seeds: - logger.debug("Attempting to connect to seed list...") - for bootstrap in seed_list: - if not is_ip_address(bootstrap): - host, port = bootstrap.split(':') - bootstrap = f"{hostname_to_ip(host)}:{port}" - addr = Address(bootstrap) - self.KNOWN_ADDRS.append(addr) - self.SetupConnection(addr) - - logger.debug("Starting up nodeleader: starting peer, mempool, and blockheight check loops") - # check in on peers every 10 seconds - self.start_peer_check_loop() - self.start_memcheck_loop() - self.start_blockheight_loop() - - if settings.ACCEPT_INCOMING_PEERS and not self.incoming_server_running: - class OneShotFactory(Factory): - def __init__(self, leader): - self.leader = leader - - def buildProtocol(self, addr): - print(f"building new protocol for addr: {addr}") - self.leader.AddKnownAddress(Address(f"{addr.host}:{addr.port}")) - p = NeoNode(incoming_client=True) - p.factory = self - return p - - def listen_err(err): - print(f"Failed start listening server for reason: {err.value}") - - def listen_ok(value): - self.incoming_server_running = True - - logger.debug(f"Starting up nodeleader: setting up listen server on port: {settings.NODE_PORT}") - server_endpoint = TCP4ServerEndpoint(self.reactor, settings.NODE_PORT) - listenport_deferred = server_endpoint.listen(OneShotFactory(leader=self)) - listenport_deferred.addCallback(listen_ok) - listenport_deferred.addErrback(listen_err) - - def setBlockReqSizeAndMax(self, breqpart=100, breqmax=10000): - if breqpart > 0 and breqpart <= 500 and breqmax > 0 and breqmax > breqpart: - self.BREQPART = breqpart - self.BREQMAX = breqmax - logger.info("Set each node to request %s blocks per request with a total of %s in queue" % (self.BREQPART, self.BREQMAX)) - return True - else: - raise ValueError("invalid values. Please specify a block request part and max size for each node, like 30 and 1000") - - def setBlockReqSizeByName(self, name): - if name.lower() == 'slow': - self.BREQPART = 15 - self.BREQMAX = 5000 - elif name.lower() == 'normal': - self.BREQPART = 100 - self.BREQMAX = 10000 - elif name.lower() == 'fast': - self.BREQPART = 250 - self.BREQMAX = 15000 - else: - logger.info("configuration name %s not found. 
use 'slow', 'normal', or 'fast'" % name) - return False - - logger.info("Set each node to request %s blocks per request with a total of %s in queue" % (self.BREQPART, self.BREQMAX)) - return True - - def RemoteNodePeerReceived(self, host, port, via_node_addr): - addr = Address("%s:%s" % (host, port)) - if addr not in self.KNOWN_ADDRS and addr not in self.DEAD_ADDRS: - logger.debug(f"Adding new address {addr:>21} to known addresses list, received from {via_node_addr}") - # we always want to save new addresses in case we lose all active connections before we can request a new list - self.KNOWN_ADDRS.append(addr) - - def SetupConnection(self, addr, endpoint=None): - if len(self.Peers) + self.peers_connecting < settings.CONNECTED_PEER_MAX: - try: - host, port = addr.split(':') - if endpoint: - point = endpoint - else: - point = TCP4ClientEndpoint(self.reactor, host, int(port), timeout=5) - self.peers_connecting += 1 - d = connectProtocol(point, NeoNode()) # type: Deferred - d.addErrback(self.clientConnectionFailed, addr) - return d - except Exception as e: - logger.error(f"Setup connection with with {e}") - - def Shutdown(self): - """Disconnect all connected peers.""" - logger.debug("Nodeleader shutting down") - - self.stop_peer_check_loop() - self.peer_check_loop_deferred = None - - self.stop_check_bcr_loop() - self.check_bcr_loop_deferred = None - - self.stop_memcheck_loop() - self.memcheck_loop_deferred = None - - self.stop_blockheight_loop() - self.blockheight_loop_deferred = None - - for p in self.Peers: - p.Disconnect() - - def AddConnectedPeer(self, peer): - """ - Add a new connect peer to the known peers list. - - Args: - peer (NeoNode): instance. - """ - # if present - self.RemoveFromQueue(peer.address) - self.AddKnownAddress(peer.address) - - if len(self.Peers) > settings.CONNECTED_PEER_MAX: - peer.Disconnect("Max connected peers reached", isDead=False) - - if peer not in self.Peers: - self.Peers.append(peer) - else: - # either peer is already in the list and it has reconnected before it timed out on our side - # or it's trying to connect multiple times - # or we hit the max connected peer count - self.RemoveKnownAddress(peer.address) - peer.Disconnect() - - def RemoveConnectedPeer(self, peer): - """ - Remove a connected peer from the known peers list. - - Args: - peer (NeoNode): instance. - """ - if peer in self.Peers: - self.Peers.remove(peer) - - def RemoveFromQueue(self, addr): - """ - Remove an address from the connection queue - Args: - addr: - - Returns: - - """ - if addr in self.connection_queue: - self.connection_queue.remove(addr) - - def RemoveKnownAddress(self, addr): - if addr in self.KNOWN_ADDRS: - self.KNOWN_ADDRS.remove(addr) - - def AddKnownAddress(self, addr): - if addr not in self.KNOWN_ADDRS: - self.KNOWN_ADDRS.append(addr) - - def AddDeadAddress(self, addr, reason=None): - if addr not in self.DEAD_ADDRS: - if reason: - logger.debug(f"Adding address {addr:>21} to DEAD_ADDRS list. Reason: {reason}") - else: - logger.debug(f"Adding address {addr:>21} to DEAD_ADDRS list.") - self.DEAD_ADDRS.append(addr) - - # something in the dead_addrs list cannot be in the known_addrs list. 
Which holds either "tested and good" or "untested" addresses - self.RemoveKnownAddress(addr) - - def PeerCheckLoop(self): - logger.debug( - f"Peer check loop...checking [A:{len(self.KNOWN_ADDRS)} D:{len(self.DEAD_ADDRS)} C:{len(self.Peers)} M:{settings.CONNECTED_PEER_MAX} " - f"Q:{len(self.connection_queue)}]") - - connected = [] - peer_to_remove = [] - - for peer in self.Peers: - if peer.endpoint == "": - peer_to_remove.append(peer) - else: - connected.append(peer.address) - for p in peer_to_remove: - self.Peers.remove(p) - - self._ensure_peer_tasks_running(connected) - self._check_for_queuing_possibilities(connected) - self._process_connection_queue() - # keep this last, to ensure we first try queueing. - self._monitor_for_zero_connected_peers() - - def _check_for_queuing_possibilities(self, connected): - # we sort addresses such that those that we recently disconnected from are last in the list - self.KNOWN_ADDRS.sort(key=lambda address: address.last_connection) - to_remove = [] - for addr in self.KNOWN_ADDRS: - if addr in self.DEAD_ADDRS: - logger.debug(f"Address {addr} found in DEAD_ADDRS list...skipping") - to_remove.append(addr) - continue - if addr not in connected and addr not in self.connection_queue and len(self.Peers) + len( - self.connection_queue) < settings.CONNECTED_PEER_MAX: - self.connection_queue.append(addr) - logger.debug( - f"Queuing {addr:>21} for new connection [in queue: {len(self.connection_queue)} " - f"connected: {len(self.Peers)} maxpeers:{settings.CONNECTED_PEER_MAX}]") - - # we couldn't remove addresses found in the DEAD_ADDR list from ADDRS while looping over it - # so we do it now to clean up - for addr in to_remove: - # TODO: might be able to remove. Check if this scenario is still possible since the refactor - try: - self.KNOWN_ADDRS.remove(addr) - except KeyError: - pass - - def _monitor_for_zero_connected_peers(self): - """ - Track if we lost connection to all peers. - Give some retries threshold to allow peers that are in the process of connecting or in the queue to be connected to run - - """ - if len(self.Peers) == 0 and len(self.connection_queue) == 0: - if self.peer_zero_count > 2: - logger.debug("Peer count 0 exceeded max retries threshold, restarting...") - self.Restart() - else: - logger.debug( - f"Peer count is 0, allow for retries or queued connections to be established {self.peer_zero_count}") - self.peer_zero_count += 1 - - def _ensure_peer_tasks_running(self, connected): - # double check that the peers that are connected are running their tasks - # unless we're data throttling - # there has been a case where the connection was established, but ProtocolReady() never called nor disconnected. - if not self.check_bcr_loop: - for peer in self.Peers: - if not peer.has_tasks_running() and peer.handshake_complete: - peer.start_all_tasks() - - def InventoryReceived(self, inventory): - """ - Process a received inventory. - - Args: - inventory (neo.Network.Inventory): expect a Block type. - - Returns: - bool: True if processed and verified. False otherwise. 
- """ - if inventory.Hash.ToBytes() in self._MissedBlocks: - self._MissedBlocks.remove(inventory.Hash.ToBytes()) - - if inventory is MinerTransaction: - return False - - if type(inventory) is Block: - if BC.Default() is None: - return False - - if BC.Default().ContainsBlock(inventory.Index): - return False - - if not BC.Default().AddBlock(inventory): - return False - - else: - if not inventory.Verify(self.MemPool.values()): - return False - - def RelayDirectly(self, inventory): - """ - Relay the inventory to the remote client. - - Args: - inventory (neo.Network.Inventory): - - Returns: - bool: True if relayed successfully. False otherwise. - """ - relayed = False - - self.RelayCache[inventory.Hash.ToBytes()] = inventory - - for peer in self.Peers: - relayed |= peer.Relay(inventory) - - if len(self.Peers) == 0: - if type(BC.Default()) is TestLevelDBBlockchain: - # mock a true result for tests - return True - - logger.info("no connected peers") - - return relayed - - def Relay(self, inventory): - """ - Relay the inventory to the remote client. - - Args: - inventory (neo.Network.Inventory): - - Returns: - bool: True if relayed successfully. False otherwise. - """ - if type(inventory) is MinerTransaction: - return False - - if inventory.Hash.ToBytes() in self.KnownHashes: - return False - - self.KnownHashes.append(inventory.Hash.ToBytes()) - - if type(inventory) is Block: - pass - - elif type(inventory) is Transaction or issubclass(type(inventory), Transaction): - if not self.AddTransaction(inventory): - # if we fail to add the transaction for whatever reason, remove it from the known hashes list or we cannot retry the same transaction again - try: - self.KnownHashes.remove(inventory.Hash.ToBytes()) - except ValueError: - # it not found - pass - return False - else: - # consensus - pass - - relayed = self.RelayDirectly(inventory) - return relayed - - def GetTransaction(self, hash): - if hash in self.MemPool.keys(): - return self.MemPool[hash] - return None - - def AddTransaction(self, tx): - """ - Add a transaction to the memory pool. - - Args: - tx (neo.Core.TX.Transaction): instance. - - Returns: - bool: True if successfully added. False otherwise. - """ - if BC.Default() is None: - return False - - if tx.Hash.ToBytes() in self.MemPool.keys(): - return False - - if BC.Default().ContainsTransaction(tx.Hash): - return False - - if not tx.Verify(self.MemPool.values()): - logger.error("Verifying tx result... failed") - return False - - self.MemPool[tx.Hash.ToBytes()] = tx - - return True - - def RemoveTransaction(self, tx): - """ - Remove a transaction from the memory pool if it is found on the blockchain. - - Args: - tx (neo.Core.TX.Transaction): instance. - - Returns: - bool: True if successfully removed. False otherwise. 
- """ - if BC.Default() is None: - return False - - if not BC.Default().ContainsTransaction(tx.Hash): - return False - - if tx.Hash.ToBytes() in self.MemPool: - del self.MemPool[tx.Hash.ToBytes()] - return True - - return False - - def MempoolCheck(self): - """ - Checks the Mempool and removes any tx found on the Blockchain - Implemented to resolve https://github.com/CityOfZion/neo-python/issues/703 - """ - txs = [] - values = self.MemPool.values() - for tx in values: - txs.append(tx) - - for tx in txs: - res = self.RemoveTransaction(tx) - if res: - logger.debug("found tx 0x%s on the blockchain ...removed from mempool" % tx.Hash) - - def BlockheightCheck(self): - """ - Checks the current blockheight and finds the peer that prevents advancement - """ - if self.CurrentBlockheight == BC.Default().Height: - if len(self.Peers) > 0: - logger.debug("Blockheight is not advancing ...") - next_hash = BC.Default().GetHeaderHash(self.CurrentBlockheight + 1) - culprit_found = False - for peer in self.Peers: - if next_hash in peer.myblockrequests: - culprit_found = True - peer.Disconnect() - break - - # this happens when we're connecting to other nodes that are stuck themselves - if not culprit_found: - for peer in self.Peers: - peer.Disconnect() - else: - self.CurrentBlockheight = BC.Default().Height - - def clientConnectionFailed(self, err, address: Address): - """ - Called when we fail to connect to an endpoint - Args: - err: Twisted Failure instance - address: the address we failed to connect to - """ - if type(err.value) == error.TimeoutError: - logger.debug(f"Failed connecting to {address} connection timed out") - elif type(err.value) == error.ConnectError: - ce = err.value - if len(ce.args) > 0: - try: - logger.debug(f"Failed connecting to {address} {ce.args[0].value}") - except AttributeError: - if isinstance(ce.args[0], str): - logger.debug(f"Failed connecting to {address} {ce.args[0]}") - else: - logger.debug(f"Failed connecting to {address}") - else: - logger.debug(f"Failed connecting to {address}") - else: - logger.debug(f"Failed connecting to {address} {err.value}") - self.peers_connecting -= 1 - self.RemoveKnownAddress(address) - self.RemoveFromQueue(address) - # if we failed to connect to new addresses, we should always add them to the DEAD_ADDRS list - self.AddDeadAddress(address) - - # for testing - return err.type - - @staticmethod - def Reset(): - NodeLeader._LEAD = None - - NodeLeader.Peers = [] - - NodeLeader.KNOWN_ADDRS = [] - NodeLeader.DEAD_ADDRS = [] - - NodeLeader.NodeId = None - - NodeLeader._MissedBlocks = [] - - NodeLeader.BREQPART = 100 - NodeLeader.BREQMAX = 10000 - - NodeLeader.KnownHashes = [] - NodeLeader.MissionsGlobal = [] - NodeLeader.MemPool = {} - NodeLeader.RelayCache = {} - - NodeLeader.NodeCount = 0 - - NodeLeader.CurrentBlockheight = 0 - - NodeLeader.ServiceEnabled = False - - NodeLeader.peer_check_loop = None - NodeLeader.peer_check_loop_deferred = None - - NodeLeader.check_bcr_loop = None - NodeLeader.check_bcr_loop_deferred = None - - NodeLeader.memcheck_loop = None - NodeLeader.memcheck_loop_deferred = None - - NodeLeader.blockheight_loop = None - NodeLeader.blockheight_loop_deferred = None - - NodeLeader.task_handles = {} - - def OnSetupConnectionErr(self, err): - if type(err.value) == CancelledError: - return - logger.debug("On setup connection error! %s" % err) - - def OnCheckBcrError(self, err): - if type(err.value) == CancelledError: - return - logger.debug("On Check BlockRequest error! 
%s" % err) - - def OnPeerLoopError(self, err): - if type(err.value) == CancelledError: - return - logger.debug("Error on Peer check loop %s " % err) - - def OnMemcheckError(self, err): - if type(err.value) == CancelledError: - return - logger.debug("Error on Memcheck check %s " % err) - - def OnBlockheightcheckError(self, err): - if type(err.value) == CancelledError: - return - logger.debug("Error on Blockheight check loop %s " % err) diff --git a/neo/Network/Payloads/AddrPayload.py b/neo/Network/Payloads/AddrPayload.py deleted file mode 100644 index c1d61c34c..000000000 --- a/neo/Network/Payloads/AddrPayload.py +++ /dev/null @@ -1,46 +0,0 @@ -from neo.Core.IO.Mixins import SerializableMixin -import sys -from neo.Core.Size import GetVarSize - - -class AddrPayload(SerializableMixin): - NetworkAddressesWithTime = [] - - def __init__(self, addresses=None): - """ - Create an instance. - - Args: - addresses (list): of neo.Network.Payloads.NetworkAddressWithTime.NetworkAddressWithTime instances. - """ - self.NetworkAddressesWithTime = addresses if addresses else [] - - def Size(self): - """ - Get the total size in bytes of the object. - - Returns: - int: size. - """ - return GetVarSize(self.NetworkAddressesWithTime) - - def Deserialize(self, reader): - """ - Deserialize full object. - - Args: - reader (neo.IO.BinaryReader): - """ - self.NetworkAddressesWithTime = reader.ReadSerializableArray( - 'neo.Network.Payloads.NetworkAddressWithTime.NetworkAddressWithTime') - - def Serialize(self, writer): - """ - Serialize object. - - Args: - writer (neo.IO.BinaryWriter): - """ - writer.WriteVarInt(len(self.NetworkAddressesWithTime)) - for address in self.NetworkAddressesWithTime: - address.Serialize(writer) diff --git a/neo/Network/Payloads/ConsensusPayload.py b/neo/Network/Payloads/ConsensusPayload.py deleted file mode 100644 index 9a88809af..000000000 --- a/neo/Network/Payloads/ConsensusPayload.py +++ /dev/null @@ -1,52 +0,0 @@ -from neo.Core.IO.Mixins import SerializableMixin -from neo.Core.Cryptography.Helper import bin_dbl_sha256 -from neo.Core.Helper import Helper -from neo.Network.InventoryType import InventoryType -from neo.Core.Size import Size as s -from neo.Core.Size import GetVarSize - - -class ConsensusPayload(SerializableMixin): - InventoryType = InventoryType.Consensus - Version = None - PrevHash = None - BlockIndex = None - ValidatorIndex = None - Timestamp = None - Data = [] - Witness = None - - _hash = None - - def Hash(self): - if not self._hash: - self._hash = bin_dbl_sha256(Helper.GetHashData(self)) - return self._hash - - def Size(self): - scriptsize = 0 - if self.Script is not None: - scriptsize = self.Script.Size() - - return s.uint32 + s.uint256 + s.uint32 + s.uint16 + s.uint32 + GetVarSize(self.Data) + 1 + scriptsize - - def GetMessage(self): - return Helper.GetHashData(self) - - def GetScriptHashesForVerifying(self): - raise NotImplementedError() - - def Deserialize(self, reader): - raise NotImplementedError('Consensus not implemented') - - def DeserializeUnsigned(self, reader): - raise NotImplementedError() - - def Serialize(self, writer): - raise NotImplementedError() - - def SerializeUnsigned(self, writer): - raise NotImplementedError() - - def Verify(self): - raise NotImplementedError() diff --git a/neo/Network/Payloads/GetBlocksPayload.py b/neo/Network/Payloads/GetBlocksPayload.py deleted file mode 100644 index e628a4d1e..000000000 --- a/neo/Network/Payloads/GetBlocksPayload.py +++ /dev/null @@ -1,52 +0,0 @@ -import sys -import binascii -from neo.Core.IO.Mixins import 
SerializableMixin -from neo.Core.UInt256 import UInt256 -from neo.Core.Size import GetVarSize - - -class GetBlocksPayload(SerializableMixin): - HashStart = [] - HashStop = None - - def __init__(self, hash_start=[], hash_stop=UInt256()): - """ - Create an instance. - - Args: - hash_start (list): a list of hash values. Each value is of the bytearray type. Note: should actually be UInt256 objects. - hash_stop (UInt256): - """ - self.HashStart = hash_start - self.HashStop = hash_stop - - def Size(self): - """ - Get the total size in bytes of the object. - - Returns: - int: size. - """ - corrected_hashes = list(map(lambda i: UInt256(data=binascii.unhexlify(i)), self.HashStart)) - return GetVarSize(corrected_hashes) + self.hash_stop.Size - - def Deserialize(self, reader): - """ - Deserialize full object. - - Args: - reader (neo.IO.BinaryReader): - """ - self.HashStart = reader.ReadSerializableArray('neo.Core.UInt256.UInt256') - self.HashStop = reader.ReadUInt256() - - def Serialize(self, writer): - """ - Serialize object. - - Args: - writer (neo.IO.BinaryWriter): - """ - writer.WriteHashes(self.HashStart) - if self.HashStop is not None: - writer.WriteUInt256(self.HashStop) diff --git a/neo/Network/Payloads/HeadersPayload.py b/neo/Network/Payloads/HeadersPayload.py deleted file mode 100644 index 2f9d27ad1..000000000 --- a/neo/Network/Payloads/HeadersPayload.py +++ /dev/null @@ -1,44 +0,0 @@ -from neo.Core.IO.Mixins import SerializableMixin -import sys -from neo.Core.Size import GetVarSize -from neo.Core.IO.BinaryWriter import BinaryWriter - - -class HeadersPayload(SerializableMixin): - Headers = [] - - def __init__(self, headers=None): - """ - Create an instance. - - Args: - headers (list): of neo.Core.Header.Header objects. - """ - self.Headers = headers if headers else [] - - def Size(self): - """ - Get the total size in bytes of the object. - - Returns: - int: size. - """ - return GetVarSize(self.Headers) - - def Deserialize(self, reader): - """ - Deserialize full object. - - Args: - reader (neo.IO.BinaryReader): - """ - self.Headers = reader.ReadSerializableArray('neo.Core.Header.Header') - - def Serialize(self, writer: BinaryWriter): - """ - Serialize object. - - Args: - writer (neo.IO.BinaryWriter): - """ - writer.WriteSerializableArray(self.Headers) diff --git a/neo/Network/Payloads/InvPayload.py b/neo/Network/Payloads/InvPayload.py deleted file mode 100644 index 261897289..000000000 --- a/neo/Network/Payloads/InvPayload.py +++ /dev/null @@ -1,71 +0,0 @@ -import binascii -from neo.Core.UInt256 import UInt256 -from neo.Core.IO.Mixins import SerializableMixin -from neo.Core.Size import Size as s -from neo.Core.Size import GetVarSize -from neo.logging import log_manager - -logger = log_manager.getLogger() - - -class InvPayload(SerializableMixin): - Type = None - Hashes = [] - - def __init__(self, type=None, hashes=None): - """ - Create an instance. - - Args: - type (neo.Network.InventoryType): - hashes (list): of bytearray items. - """ - self.Type = type - self.Hashes = hashes if hashes else [] - - def Size(self): - """ - Get the total size in bytes of the object. - - Returns: - int: size. - """ - if len(self.Hashes) > 0: - if not isinstance(self.Hashes[0], UInt256): - corrected_hashes = list(map(lambda i: UInt256(data=binascii.unhexlify(i)), self.Hashes)) - return s.uint8 + GetVarSize(corrected_hashes) - - def Deserialize(self, reader): - """ - Deserialize full object. 
- - Args: - reader (neo.IO.BinaryReader): - """ - self.Type = ord(reader.ReadByte()) - self.Hashes = reader.ReadHashes() - - def Serialize(self, writer): - """ - Serialize object. - - Raises: - Exception: if hash writing fails. - - Args: - writer (neo.IO.BinaryWriter): - """ - try: - writer.WriteByte(self.Type) - writer.WriteHashes(self.Hashes) - except Exception as e: - logger.error(f"COULD NOT WRITE INVENTORY HASHES ({self.Type} {self.Hashes}) {e}") - - def ToString(self): - """ - Get the string representation of the payload. - - Returns: - str: - """ - return "INVENTORY Type %s hashes %s " % (self.Type, [h for h in self.Hashes]) diff --git a/neo/Network/Payloads/NetworkAddressWithTime.py b/neo/Network/Payloads/NetworkAddressWithTime.py deleted file mode 100644 index 640387702..000000000 --- a/neo/Network/Payloads/NetworkAddressWithTime.py +++ /dev/null @@ -1,83 +0,0 @@ -import ctypes -from datetime import datetime -from neo.Core.IO.Mixins import SerializableMixin -from neo.Core.Size import Size as s - - -class NetworkAddressWithTime(SerializableMixin): - NODE_NETWORK = 1 - - Timestamp = None - Services = None - Address = None - Port = None - - def __init__(self, address=None, port=None, services=0, timestamp=int(datetime.utcnow().timestamp())): - """ - Create an instance. - - Args: - address (str): - port (int): - services (int): - timestamp (int): - """ - self.Address = address - self.Port = port - self.Services = services - self.Timestamp = timestamp - - def Size(self): - """ - Get the total size in bytes of the object. - - Returns: - int: size. - """ - return s.uint32 + s.uint64 + 16 + s.uint16 - - def Deserialize(self, reader): - """ - Deserialize full object. - - Args: - reader (neo.IO.BinaryReader): - """ - self.Timestamp = reader.ReadUInt32() - self.Services = reader.ReadUInt64() - addr = bytearray(reader.ReadFixedString(16)) - addr.reverse() - addr.strip(b'\x00') - nums = [] - for i in range(0, 4): - nums.append(str(addr[i])) - nums.reverse() - adddd = '.'.join(nums) - self.Address = adddd - self.Port = reader.ReadUInt16(endian='>') - - def Serialize(self, writer): - """ - Serialize object. - - Args: - writer (neo.IO.BinaryWriter): - """ - writer.WriteUInt32(self.Timestamp) - writer.WriteUInt64(self.Services) - # turn ip address into bytes - octets = bytearray(map(lambda oct: int(oct), self.Address.split('.'))) - # pad to fixed length 16 - octets += bytearray(12) - # and finally write to stream - writer.WriteBytes(octets) - writer.WriteUInt16(self.Port, endian='>') - - def ToString(self): - """ - Get the string representation of the network address. - - Returns: - str: address:port - """ - return '%s:%s' % (self.Address, self.Port) diff --git a/neo/Network/Payloads/VersionPayload.py b/neo/Network/Payloads/VersionPayload.py deleted file mode 100644 index fe36e289f..000000000 --- a/neo/Network/Payloads/VersionPayload.py +++ /dev/null @@ -1,84 +0,0 @@ -import datetime -from neo.Core.IO.Mixins import SerializableMixin -from neo.Network.Payloads.NetworkAddressWithTime import NetworkAddressWithTime -from neo.Core.Blockchain import Blockchain -from neo.Core.Size import Size as s -from neo.Core.Size import GetVarSize -from neo.logging import log_manager - -logger = log_manager.getLogger() - - -class VersionPayload(SerializableMixin): - Version = None - Services = None - Timestamp = None - Port = None - Nonce = None - UserAgent = None - StartHeight = 0 - Relay = False - - def __init__(self, port=None, nonce=None, userAgent=None): - """ - Create an instance. 
- - Args: - port (int): - nonce (int): - userAgent (str): client user agent string. - """ - if port and nonce and userAgent: - self.Port = port - self.Version = 0 - self.Services = NetworkAddressWithTime.NODE_NETWORK - self.Timestamp = int(datetime.datetime.utcnow().timestamp()) - self.Nonce = nonce - self.UserAgent = userAgent - - if Blockchain.Default() is not None and Blockchain.Default().Height is not None: - self.StartHeight = Blockchain.Default().Height - - self.Relay = True - - def Size(self): - """ - Get the total size in bytes of the object. - - Returns: - int: size. - """ - return s.uint32 + s.uint64 + s.uint32 + s.uint16 + s.uint32 + GetVarSize(self.UserAgent) + s.uint32 + s.uint8 - - def Deserialize(self, reader): - """ - Deserialize full object. - - Args: - reader (neo.IO.BinaryReader): - """ - self.Version = reader.ReadUInt32() - self.Services = reader.ReadUInt64() - self.Timestamp = reader.ReadUInt32() - self.Port = reader.ReadUInt16() - self.Nonce = reader.ReadUInt32() - self.UserAgent = reader.ReadVarString().decode('utf-8') - self.StartHeight = reader.ReadUInt32() - logger.debug("Version start height: T %s " % self.StartHeight) - self.Relay = reader.ReadBool() - - def Serialize(self, writer): - """ - Serialize object. - - Args: - writer (neo.IO.BinaryWriter): - """ - writer.WriteUInt32(self.Version) - writer.WriteUInt64(self.Services) - writer.WriteUInt32(self.Timestamp) - writer.WriteUInt16(self.Port) - writer.WriteUInt32(self.Nonce) - writer.WriteVarString(self.UserAgent) - writer.WriteUInt32(self.StartHeight) - writer.WriteBool(self.Relay) diff --git a/neo/Network/Payloads/test_payloads.py b/neo/Network/Payloads/test_payloads.py deleted file mode 100644 index 56c7a78fa..000000000 --- a/neo/Network/Payloads/test_payloads.py +++ /dev/null @@ -1,120 +0,0 @@ -import random -import binascii -from datetime import datetime - -from neo.Utils.NeoTestCase import NeoTestCase -from neo.Network.Payloads.VersionPayload import VersionPayload -from neo.Network.Payloads.NetworkAddressWithTime import NetworkAddressWithTime -from neo.Network.Message import Message -from neo.IO.Helper import Helper as IOHelper -from neo.Core.IO.BinaryWriter import BinaryWriter -from neo.Core.IO.BinaryReader import BinaryReader -from neo.IO.MemoryStream import StreamManager -from neo.Settings import settings -from neo.Core.Helper import Helper - - -class PayloadTestCase(NeoTestCase): - - port = 20333 - nonce = random.randint(12949672, 42949672) - ua = "/NEO:2.4.1/" - - payload = None - - def setUp(self): - - self.payload = VersionPayload(self.port, self.nonce, self.ua) - - def test_version_create(self): - - self.assertEqual(self.payload.Nonce, self.nonce) - self.assertEqual(self.payload.Port, self.port) - self.assertEqual(self.payload.UserAgent, self.ua) - - def test_version_serialization(self): - - serialized = binascii.unhexlify(Helper.ToArray(self.payload)) - - deserialized_version = IOHelper.AsSerializableWithType(serialized, 'neo.Network.Payloads.VersionPayload.VersionPayload') - - v = deserialized_version - self.assertEqual(v.Nonce, self.nonce) - self.assertEqual(v.Port, self.port) - self.assertEqual(v.UserAgent, self.ua) - self.assertEqual(v.Timestamp, self.payload.Timestamp) - self.assertEqual(v.StartHeight, self.payload.StartHeight) - self.assertEqual(v.Version, self.payload.Version) - self.assertEqual(v.Services, self.payload.Services) - self.assertEqual(v.Relay, self.payload.Relay) - - def test_message_serialization(self): - - message = Message('version', payload=self.payload) - - 
self.assertEqual(message.Command, 'version') - - ms = StreamManager.GetStream() - writer = BinaryWriter(ms) - - message.Serialize(writer) - - result = binascii.unhexlify(ms.ToArray()) - StreamManager.ReleaseStream(ms) - - ms = StreamManager.GetStream(result) - reader = BinaryReader(ms) - - deserialized_message = Message() - deserialized_message.Deserialize(reader) - - StreamManager.ReleaseStream(ms) - - dm = deserialized_message - - self.assertEqual(dm.Command, 'version') - - self.assertEqual(dm.Magic, settings.MAGIC) - - checksum = Message.GetChecksum(dm.Payload) - - self.assertEqual(checksum, dm.Checksum) - - deserialized_version = IOHelper.AsSerializableWithType(dm.Payload, 'neo.Network.Payloads.VersionPayload.VersionPayload') - - self.assertEqual(deserialized_version.Port, self.port) - self.assertEqual(deserialized_version.UserAgent, self.ua) - - self.assertEqual(deserialized_version.Timestamp, self.payload.Timestamp) - - def test_network_addrtime(self): - - addr = "55.15.69.104" - port = 10333 - ts = int(datetime.now().timestamp()) - services = 0 - - nawt = NetworkAddressWithTime(addr, port, services, ts) - - ms = StreamManager.GetStream() - writer = BinaryWriter(ms) - - nawt.Serialize(writer) - - arr = ms.ToArray() - arhex = binascii.unhexlify(arr) - - StreamManager.ReleaseStream(ms) - - ms = StreamManager.GetStream(arhex) - reader = BinaryReader(ms) - - nawt2 = NetworkAddressWithTime() - nawt2.Deserialize(reader) - - StreamManager.ReleaseStream(ms) - -# self.assertEqual(nawt.Address, nawt2.Address) - self.assertEqual(nawt.Services, nawt2.Services) - self.assertEqual(nawt.Port, nawt2.Port) - self.assertEqual(nawt.Timestamp, nawt2.Timestamp) diff --git a/neo/Network/Utils.py b/neo/Network/Utils.py deleted file mode 100644 index 393ddab04..000000000 --- a/neo/Network/Utils.py +++ /dev/null @@ -1,60 +0,0 @@ -from twisted.internet import task, interfaces, defer -from zope.interface import implementer -from twisted.test import proto_helpers -from twisted.internet.endpoints import _WrappingFactory -import socket -import ipaddress - - -class LoopingCall(task.LoopingCall): - """ - A testable looping call - """ - - def __init__(self, *a, **kw): - if 'clock' in kw: - clock = kw['clock'] - del kw['clock'] - super(LoopingCall, self).__init__(*a, **kw) - - self.clock = clock - - -@implementer(interfaces.IStreamClientEndpoint) -class TestTransportEndpoint(object): - """ - Helper class for testing - """ - - def __init__(self, reactor, addr, tr=None): - self.reactor = reactor - self.addr = addr - self.tr = proto_helpers.StringTransport() - if tr: - self.tr = tr - - def connect(self, protocolFactory): - """ - Implement L{IStreamClientEndpoint.connect} to connect via StringTransport. 
- """ - try: - node = protocolFactory.buildProtocol((self.addr)) - node.makeConnection(self.tr) - # because the Twisted `StringTransportWithDisconnection` helper class tries to weirdly enough access `protocol` on a transport - self.tr.protocol = node - return defer.succeed(node) - except Exception: - return defer.fail() - - -def hostname_to_ip(hostname): - return socket.gethostbyname(hostname) - - -def is_ip_address(hostname): - host = hostname.split(':')[0] - try: - ip = ipaddress.ip_address(host) - return True - except ValueError: - return False diff --git a/neo/Network/address.py b/neo/Network/address.py deleted file mode 100755 index f52100be5..000000000 --- a/neo/Network/address.py +++ /dev/null @@ -1,48 +0,0 @@ -import datetime - - -class Address: - def __init__(self, address: str, last_connection_to: float = None): - """ - Initialize - Args: - address: a host:port - last_connection_to: timestamp since we were last connected. Default's to 0 indicating 'never' - """ - if not last_connection_to: - self.last_connection = 0 - else: - self.last_connection = last_connection_to - - self.address = address # type: str - - @classmethod - def Now(cls): - return datetime.datetime.utcnow().timestamp() - - def __eq__(self, other): - if type(other) is type(self): - return self.address == other.address - else: - return False - - def __repr__(self): - return f"<{self.__class__.__name__} at {hex(id(self))}> {self.address} ({self.last_connection:.2f})" - - def __str__(self): - return self.address - - def __call__(self, *args, **kwargs): - return self.address - - def __hash__(self): - return hash((self.address, self.last_connection)) - - def __format__(self, format_spec): - return self.address.__format__(format_spec) - - def split(self, on): - return self.address.split(on) - - def rsplit(self, on, maxsplit): - return self.address.rsplit(on, maxsplit) diff --git a/neo/Network/common/__init__.py b/neo/Network/common/__init__.py new file mode 100755 index 000000000..5ebcdbf6a --- /dev/null +++ b/neo/Network/common/__init__.py @@ -0,0 +1,81 @@ +import asyncio +import string +from neo.Network.common.events import Events +from contextlib import contextmanager + +from prompt_toolkit.eventloop import set_event_loop as prompt_toolkit_set_event_loop +from prompt_toolkit.eventloop import create_asyncio_event_loop as prompt_toolkit_create_async_event_loop +from prompt_toolkit import prompt + +msgrouter = Events() + + +def wait_for(coro): + with get_event_loop() as loop: + return loop.run_until_complete(coro) + + +def blocking_prompt(text, **kwargs): + with get_event_loop() as loop: + return loop.run_until_complete(prompt(text, async_=True, **kwargs)) + + +class LoopPool: + def __init__(self): + self.loops = set() + + def borrow_loop(self): + try: + return self.loops.pop() + except KeyError: + return asyncio.new_event_loop() + + def return_loop(self, loop): + # loop.stop() + self.loops.add(loop) + + +loop_pool = LoopPool() + + +@contextmanager +def get_event_loop(): + loop = asyncio.get_event_loop() + if not loop.is_running(): + yield loop + else: + new_loop = loop_pool.borrow_loop() + asyncio.set_event_loop(new_loop) + prompt_loop = loop_pool.borrow_loop() + new_prompt_loop = prompt_toolkit_create_async_event_loop(new_loop) + prompt_toolkit_set_event_loop(new_prompt_loop) + running_loop = asyncio.events._get_running_loop() + asyncio.events._set_running_loop(None) + try: + yield new_loop + finally: + loop_pool.return_loop(new_loop) + loop_pool.return_loop(prompt_loop) + asyncio.set_event_loop(loop) + 
prompt_toolkit_set_event_loop(prompt_toolkit_create_async_event_loop(loop)) + asyncio.events._set_running_loop(running_loop) + + +chars = string.digits + string.ascii_letters +base = len(chars) + + +def encode_base62(num: int): + """Encode number in base62, returns a string.""" + if num < 0: + raise ValueError('cannot encode negative numbers') + + if num == 0: + return chars[0] + + digits = [] + while num: + rem = num % base + num = num // base + digits.append(chars[rem]) + return ''.join(reversed(digits)) diff --git a/neo/Network/common/events.py b/neo/Network/common/events.py new file mode 100755 index 000000000..7d0cb5b85 --- /dev/null +++ b/neo/Network/common/events.py @@ -0,0 +1,128 @@ +import asyncio + +""" + Events + ~~~~~~ + + Implements C#-Style Events. + + Derived from the original work by Zoran Isailovski: + http://code.activestate.com/recipes/410686/ - Copyright (c) 2005 + + :copyright: (c) 2014-2017 by Nicola Iarocci. + :license: BSD, see LICENSE for more details. + + Expanded to support async event calling by Erik van den Brink +""" + + +class EventsException(Exception): + pass + + +class Events: + """ + Encapsulates the core to event subscription and event firing, and feels + like a "natural" part of the language. + + The class Events is there mainly for 3 reasons: + + - Events (Slots) are added automatically, so there is no need to + declare/create them separately. This is great for prototyping. (Note + that `__events__` is optional and should primarilly help detect + misspelled event names.) + - To provide (and encapsulate) some level of introspection. + - To "steel the name" and hereby remove unneeded redundancy in a call + like: + + xxx.OnChange = event('OnChange') + """ + + def __init__(self, events=None): + + if events is not None: + + try: + for _ in events: + break + except Exception: + raise AttributeError("type object %s is not iterable" % + (type(events))) + else: + self.__events__ = events + + def __getattr__(self, name): + if name.startswith('__'): + raise AttributeError("type object '%s' has no attribute '%s'" % + (self.__class__.__name__, name)) + + if hasattr(self, '__events__'): + if name not in self.__events__: + raise EventsException("Event '%s' is not declared" % name) + + elif hasattr(self.__class__, '__events__'): + if name not in self.__class__.__events__: + raise EventsException("Event '%s' is not declared" % name) + + self.__dict__[name] = ev = _EventSlot(name) + return ev + + def __repr__(self): + return '<%s.%s object at %s>' % (self.__class__.__module__, + self.__class__.__name__, + hex(id(self))) + + __str__ = __repr__ + + def __len__(self): + return len(self.__dict__.items()) + + def __iter__(self): + def gen(dictitems=self.__dict__.items()): + for attr, val in dictitems: + if isinstance(val, _EventSlot): + yield val + + return gen() + + +class _EventSlot: + def __init__(self, name): + self.targets = [] + self.__name__ = name + + def __repr__(self): + return "event '%s'" % self.__name__ + + def __call__(self, *a, **kw): + tasks = [] + for f in tuple(self.targets): + if asyncio.coroutines.iscoroutinefunction(f): + tasks.append(asyncio.create_task(f(*a, **kw))) + else: + f(*a, **kw) + + if len(tasks) > 0: + return asyncio.gather(*tasks) + + def __iadd__(self, f): + self.targets.append(f) + return self + + def __isub__(self, f): + while f in self.targets: + self.targets.remove(f) + return self + + def __len__(self): + return len(self.targets) + + def __iter__(self): + def gen(): + for target in self.targets: + yield target + + return gen() + + def 
__getitem__(self, key): + return self.targets[key] diff --git a/neo/Network/common/singleton.py b/neo/Network/common/singleton.py new file mode 100755 index 000000000..20fac4390 --- /dev/null +++ b/neo/Network/common/singleton.py @@ -0,0 +1,19 @@ +""" +Courtesy of Guido: https://www.python.org/download/releases/2.2/descrintro/#__new__ + +To create a singleton class, you subclass from Singleton; each subclass will have a single instance, no matter how many times its constructor is called. +To further initialize the subclass instance, subclasses should override 'init' instead of __init__ - the __init__ method is called each time the constructor is called. +""" + + +class Singleton(object): + def __new__(cls, *args, **kwds): + it = cls.__dict__.get("__it__") + if it is not None: + return it + cls.__it__ = it = object.__new__(cls) + it.init(*args, **kwds) + return it + + def init(self, *args, **kwds): + pass diff --git a/neo/Network/Payloads/__init__.py b/neo/Network/core/__init__.py similarity index 100% rename from neo/Network/Payloads/__init__.py rename to neo/Network/core/__init__.py diff --git a/neo/Network/core/blockbase.py b/neo/Network/core/blockbase.py new file mode 100644 index 000000000..54ff3c5c9 --- /dev/null +++ b/neo/Network/core/blockbase.py @@ -0,0 +1,78 @@ +import hashlib +from neo.Network.core.exceptions import DeserializationError +from neo.Network.core.mixin.serializable import SerializableMixin +from neo.Network.core.uint256 import UInt256 +from neo.Network.core.uint160 import UInt160 +from neo.Network.core.io.binary_reader import BinaryReader +from neo.Network.core.io.binary_writer import BinaryWriter + + +class BlockBase(SerializableMixin): + + def __init__(self, version: int, prev_hash: UInt256, merkle_root: UInt256, timestamp: int, index: int, consensus_data, next_consensus: UInt160, witness): + self.version = version + self.prev_hash = prev_hash + self.merkle_root = merkle_root + self.timestamp = timestamp + self.index = index + self.consensus_data = consensus_data + self.next_consensus = next_consensus + self.witness = bytearray() # witness + + @property + def hash(self): + writer = BinaryWriter(stream=bytearray()) + self.serialize_unsigned(writer) + hash_data = writer._stream.getvalue() + hash = hashlib.sha256(hashlib.sha256(hash_data).digest()).digest() + writer.cleanup() + return UInt256(data=hash) + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. """ + self.serialize_unsigned(writer) + + writer.write_uint8(1) + # TODO: Normally we should write a Witness object + # we did not implement this at this moment because we don't need this data. + # writer.write_var_bytes(self.witness) + # so instead we just write 0 length indicators for the 2 members of script + writer.write_var_int(0) # invocation script length + writer.write_var_int(0) # verification script length + + def serialize_unsigned(self, writer: 'BinaryWriter') -> None: + """ Serialize unsigned object data only. """ + writer.write_uint32(self.version) + writer.write_uint256(self.prev_hash) + writer.write_uint256(self.merkle_root) + writer.write_uint32(self.timestamp) + writer.write_uint32(self.index) + writer.write_uint64(self.consensus_data) + writer.write_uint160(self.next_consensus) + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize object. 
""" + self.version = reader.read_uint32() + self.prev_hash = reader.read_uint256() + self.merkle_root = reader.read_uint256() + self.timestamp = reader.read_uint32() + self.index = reader.read_uint32() + self.consensus_data = reader.read_uint64() + self.next_consensus = reader.read_uint160() + + val = reader.read_byte() + if int(val.hex()) != 1: + raise DeserializationError(f"expected 1 got {val}") + + # TODO: self.witness = reader.read(Witness()) + # witness consists of InvocationScript + VerificationScript + # instead of a full implementation we just have a bytearray as we don't need the data + raw_witness = reader.read_var_bytes() # invocation script + raw_witness += reader.read_var_bytes() # verification script + + def to_array(self) -> bytearray: + writer = BinaryWriter(stream=bytearray()) + self.serialize(writer) + data = bytearray(writer._stream.getvalue()) + writer.cleanup() + return data diff --git a/neo/Network/core/exceptions.py b/neo/Network/core/exceptions.py new file mode 100644 index 000000000..50f60fd5d --- /dev/null +++ b/neo/Network/core/exceptions.py @@ -0,0 +1,2 @@ +class DeserializationError(Exception): + pass diff --git a/neo/Network/core/header.py b/neo/Network/core/header.py new file mode 100644 index 000000000..2baa7c3c5 --- /dev/null +++ b/neo/Network/core/header.py @@ -0,0 +1,59 @@ +from neo.Network.core.blockbase import BlockBase +from neo.Network.core.exceptions import DeserializationError +from neo.Network.core.uint256 import UInt256 +from typing import Union +from neo.Network.core.io.binary_reader import BinaryReader +from neo.Network.core.io.binary_writer import BinaryWriter + + +class Header(BlockBase): + def __init__(self, prev_hash, merkle_root, timestamp, index, consensus_data, next_consensus, witness): + version = 0 + temp_merkeroot = UInt256.zero() + super(Header, self).__init__(version, prev_hash, temp_merkeroot, timestamp, index, consensus_data, next_consensus, witness) + + self.prev_hash = prev_hash + self.merkle_root = merkle_root + self.timestamp = timestamp + self.index = index + self.consensus_data = consensus_data + self.next_consensus = next_consensus + self.witness = bytearray() # witness + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. """ + super(Header, self).serialize(writer) + writer.write_uint8(0) + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize object + + Raises: + DeserializationError: if insufficient or incorrect data + """ + super(Header, self).deserialize(reader) + try: + val = reader.read_byte() + if int(val.hex()) != 0: + raise DeserializationError(f"expected 0 got {val}") + except ValueError as ve: + raise DeserializationError(str(ve)) + + @classmethod + def deserialize_from_bytes(cls, data_stream: Union[bytes, bytearray]) -> 'Header': + """ Deserialize object from a byte array. 
""" + br = BinaryReader(stream=data_stream) + header = cls(None, None, None, None, None, None, None) + try: + header.deserialize(br) + except DeserializationError: + return None + br.cleanup() + return header + + def to_array(self) -> bytearray: + writer = BinaryWriter(stream=bytearray()) + self.serialize(writer) + data = bytearray(writer._stream.getvalue()) + writer.cleanup() + return data diff --git a/neo/Network/from_scratch/__init__.py b/neo/Network/core/io/__init__.py similarity index 100% rename from neo/Network/from_scratch/__init__.py rename to neo/Network/core/io/__init__.py diff --git a/neo/Network/core/io/binary_reader.py b/neo/Network/core/io/binary_reader.py new file mode 100644 index 000000000..10be44f98 --- /dev/null +++ b/neo/Network/core/io/binary_reader.py @@ -0,0 +1,226 @@ +import sys +import struct +from typing import Union, Any +from neo.Network.core.uint256 import UInt256 +from neo.Network.core.uint160 import UInt160 +from neo.IO.MemoryStream import StreamManager + + +class BinaryReader(object): + """A convenience class for reading data from byte streams""" + + def __init__(self, stream: Union[bytes, bytearray]) -> None: + """ + Create an instance. + + Args: + stream (BytesIO, bytearray): a stream to operate on. + """ + super(BinaryReader, self).__init__() + self._stream = StreamManager.GetStream(stream) + + def _unpack(self, fmt, length=1) -> Any: + """ + Unpack the stream contents according to the specified format in `fmt`. + For more information about the `fmt` format see: https://docs.python.org/3/library/struct.html + + Args: + fmt (str): format string. + length (int): amount of bytes to read. + + Returns: + variable: the result according to the specified format. + """ + try: + values = struct.unpack(fmt, self._stream.read(length)) + return values[0] + except struct.error as e: + raise ValueError(e) + + def read_byte(self) -> bytes: + """ + Read a single byte. + + Raises: + ValueError: if 1 byte of data cannot be read from the stream + + Returns: + bytes: a single byte. + """ + value = self._stream.read(1) + if len(value) != 1: + raise ValueError("Could not read byte from empty stream") + return value + + def read_bytes(self, length: int) -> bytes: + """ + Read the specified number of bytes from the stream. + + Args: + length (int): number of bytes to read. + + Returns: + bytes: `length` number of bytes. + """ + value = self._stream.read(length) + if len(value) != length: + raise ValueError("Could not read {} bytes from stream. Only found {} bytes of data".format(length, len(value))) + + return value + + def read_bool(self) -> bool: + """ + Read 1 byte as a boolean value from the stream. + + Returns: + bool: + """ + return self._unpack('?') + + def read_uint8(self, endian="<"): + """ + Read 1 byte as an unsigned integer value from the stream. + + Args: + endian (str): specify the endianness. (Default) Little endian ('<'). Use '>' for big endian. + + Returns: + int: + """ + return self._unpack('%sB' % endian) + + def read_uint16(self, endian="<"): + """ + Read 2 byte as an unsigned integer value from the stream. + + Args: + endian (str): specify the endianness. (Default) Little endian ('<'). Use '>' for big endian. + + Returns: + int: + """ + return self._unpack('%sH' % endian, 2) + + def read_uint32(self, endian="<"): + """ + Read 4 bytes as an unsigned integer value from the stream. + + Args: + endian (str): specify the endianness. (Default) Little endian ('<'). Use '>' for big endian. 
+ + Returns: + int: + """ + return self._unpack('%sI' % endian, 4) + + def read_uint64(self, endian="<"): + """ + Read 8 bytes as an unsigned integer value from the stream. + + Args: + endian (str): specify the endianness. (Default) Little endian ('<'). Use '>' for big endian. + + Returns: + int: + """ + return self._unpack('%sQ' % endian, 8) + + def read_var_int(self, max=sys.maxsize) -> int: + """ + Read a variable length integer from the stream. + The NEO network protocol supports encoded storage for space saving. See: http://docs.neo.org/en-us/node/network-protocol.html#convention + + Args: + max: (Optional) maximum number of bytes to read. + + Returns: + int: + """ + fb = int.from_bytes(self.read_byte(), 'little') + if fb is 0: + return fb + + if fb == 0xfd: + value = self.read_uint16() + elif fb == 0xfe: + value = self.read_uint32() + elif fb == 0xff: + value = self.read_uint64() + else: + value = fb + + if value > max: + raise ValueError("Invalid format") + + return value + + def read_var_bytes(self, max=sys.maxsize): + """ + Read a variable length of bytes from the stream. + The NEO network protocol supports encoded storage for space saving. See: http://docs.neo.org/en-us/node/network-protocol.html#convention + + Args: + max (int): (Optional) maximum number of bytes to read. + + Returns: + bytes: + """ + length = self.read_var_int(max) + return self.read_bytes(length) + + def read_var_string(self, max=sys.maxsize) -> str: + """ + Similar to `ReadString` but expects a variable length indicator instead of the fixed 1 byte indicator. + + Args: + max (int): (Optional) maximum number of bytes to read. + + Returns: + bytes: + """ + length = self.read_var_int(max) + try: + data = self._unpack(str(length) + 's', length) + return data.decode('utf-8') + + except UnicodeDecodeError as e: + raise e + except Exception as e: + raise e + + def read_fixed_string(self, length: int) -> str: + """ + Read a fixed length string from the stream. + + Args: + length (int): length of string to read. + + Raises: + ValueError: if not enough data could be read from the stream + + Returns: + str: + """ + return self.read_bytes(length).rstrip(b'\x00') + + def read_uint256(self): + """ + Read a UInt256 value from the stream. + + Returns: + UInt256: + """ + return UInt256(data=bytearray(self.read_bytes(32))) + + def read_uint160(self): + """ + Read a UInt160 value from the stream. + + Returns: + UInt160: + """ + return UInt160(data=bytearray(self.read_bytes(20))) + + def cleanup(self): + if self._stream: + StreamManager.ReleaseStream(self._stream) diff --git a/neo/Network/core/io/binary_writer.py b/neo/Network/core/io/binary_writer.py new file mode 100644 index 000000000..6a820c720 --- /dev/null +++ b/neo/Network/core/io/binary_writer.py @@ -0,0 +1,175 @@ +import struct +import binascii +import io +from typing import Union +from neo.IO.MemoryStream import StreamManager + + +class BinaryWriter(object): + """A convenience class for writing data from byte streams""" + + def __init__(self, stream: Union[bytearray, bytes]) -> None: + """ + Create an instance. + + Args: + stream: a stream to operate on. + """ + super(BinaryWriter, self).__init__() + self._stream = StreamManager.GetStream(stream) + + def write_bytes(self, value: bytes, unhex: bool = True) -> int: + """ + Write a `bytes` type to the stream. + Args: + value: array of bytes to write to the stream. + unhex: (Default) True. Set to unhexlify the stream. Use when the bytes are not raw bytes; i.e. b'aabb' + Returns: + int: the number of bytes written. 
+ """ + if unhex: + try: + value = binascii.unhexlify(value) + except binascii.Error: + pass + return self._stream.write(value) + + def _pack(self, fmt, data) -> int: + """ + Write bytes by packing them according to the provided format `fmt`. + For more information about the `fmt` format see: https://docs.python.org/3/library/struct.html + Args: + fmt (str): format string. + data (object): the data to write to the raw stream. + Returns: + int: the number of bytes written. + """ + return self.write_bytes(struct.pack(fmt, data), unhex=False) + + def write_bool(self, value: bool) -> int: + """ + Pack the value as a bool and write 1 byte to the stream. + Args: + value: the boolean value to write. + Returns: + int: the number of bytes written. + """ + return self._pack('?', value) + + def write_uint8(self, value): + return self.write_bytes(bytes([value])) + + def write_uint16(self, value, endian="<"): + """ + Pack the value as an unsigned integer and write 2 bytes to the stream. + Args: + value: + endian: specify the endianness. (Default) Little endian ('<'). Use '>' for big endian. + Returns: + int: the number of bytes written. + """ + return self._pack('%sH' % endian, value) + + def write_uint32(self, value, endian="<") -> int: + """ + Pack the value as a signed integer and write 4 bytes to the stream. + Args: + value: + endian: specify the endianness. (Default) Little endian ('<'). Use '>' for big endian. + Returns: + int: the number of bytes written. + """ + return self._pack('%sI' % endian, value) + + def write_uint64(self, value, endian="<") -> int: + """ + Pack the value as an unsigned integer and write 8 bytes to the stream. + Args: + value: + endian (str): specify the endianness. (Default) Little endian ('<'). Use '>' for big endian. + Returns: + int: the number of bytes written. + """ + return self._pack('%sQ' % endian, value) + + def write_uint256(self, value, endian="<") -> int: + return self.write_bytes(value._data) + + def write_uint160(self, value, endian="<") -> int: + return self.write_bytes(value._data) + + def write_var_string(self, value: str, encoding: str = "utf-8") -> int: + """ + Write a string value to the stream. + Read more about variable size encoding here: http://docs.neo.org/en-us/node/network-protocol.html#convention + Args: + value: value to write to the stream. + encoding: string encoding format. + """ + if type(value) is str: + data = value.encode(encoding) + + length = len(data) + self.write_var_int(length) + written = self.write_bytes(data) + return written + + def write_var_int(self, value: int, endian: str = "<") -> int: + """ + Write an integer value in a space saving way to the stream. + Read more about variable size encoding here: http://docs.neo.org/en-us/node/network-protocol.html#convention + Args: + value: + endian: specify the endianness. (Default) Little endian ('<'). Use '>' for big endian. + Raises: + {TypeError}: if ``value`` is not of type int. + ValueError: if `value` is < 0. + Returns: + int: the number of bytes written. + """ + if not isinstance(value, int): + raise TypeError('%s not int type.' % value) + + if value < 0: + raise ValueError('%d too small.' 
% value) + + elif value < 0xfd: + return self.write_bytes(bytes([value])) + + elif value <= 0xffff: + self.write_bytes(bytes([0xfd])) + return self.write_uint16(value, endian) + + elif value <= 0xFFFFFFFF: + self.write_bytes(bytes([0xfe])) + return self.write_uint32(value, endian) + + else: + self.write_bytes(bytes([0xff])) + return self.write_uint64(value, endian) + + def write_fixed_string(self, value, length): + """ + Write a string value to the stream. + Args: + value (str): value to write to the stream. + length (int): length of the string to write. + """ + towrite = value.encode('utf-8') + slen = len(towrite) + if slen > length: + raise Exception("string longer than fixed length: %s " % length) + self.write_bytes(towrite) + diff = length - slen + + while diff > 0: + self.write_bytes(bytes([0])) + diff -= 1 + + def write_var_bytes(self, value: int, endian: str = "<") -> int: + self.write_var_int(len(value), endian) + return self.write_bytes(value) + + def cleanup(self): + if self._stream: + StreamManager.ReleaseStream(self._stream) diff --git a/neo/Network/core/io/test_binary_reader.py b/neo/Network/core/io/test_binary_reader.py new file mode 100644 index 000000000..36f715b91 --- /dev/null +++ b/neo/Network/core/io/test_binary_reader.py @@ -0,0 +1,39 @@ +import unittest +from neo.Network.core.io.binary_reader import BinaryReader + + +class BinaryReaderTest(unittest.TestCase): + def test_initialization_with_bytearray(self): + data = b'\xaa\xbb' + x = BinaryReader(stream=bytearray(data)) + self.assertTrue(data, x._stream.getvalue()) + + def test_reading_bytes(self): + data = b'\xaa\xbb\xCC' + x = BinaryReader(stream=bytearray(data)) + + read_one = x.read_byte() + self.assertEqual(1, len(read_one)) + self.assertEqual(b'\xaa', read_one) + + read_two = x.read_bytes(2) + self.assertEqual(2, len(read_two)) + self.assertEqual(b'\xbb\xcc', read_two) + + def test_read_more_data_than_available(self): + data = b'\xaa\xbb' + x = BinaryReader(stream=bytearray(data)) + + with self.assertRaises(ValueError) as context: + x.read_bytes(3) + expected_error = "Could not read 3 bytes from stream. 
Only found 2 bytes of data" + self.assertEqual(expected_error, str(context.exception)) + + def test_read_byte_from_empty_stream(self): + x = BinaryReader(stream=bytearray()) + + with self.assertRaises(ValueError) as context: + x.read_byte() + + expected_error = "Could not read byte from empty stream" + self.assertEqual(expected_error, str(context.exception)) diff --git a/neo/Network/core/io/test_binary_writer.py b/neo/Network/core/io/test_binary_writer.py new file mode 100644 index 000000000..97f9cf279 --- /dev/null +++ b/neo/Network/core/io/test_binary_writer.py @@ -0,0 +1,10 @@ +import unittest +from neo.Network.core.io.binary_writer import BinaryWriter + + +class BinaryWriterTest(unittest.TestCase): + def test_var_string(self): + data = "hello" + b = BinaryWriter(stream=bytearray()) + b.write_var_string(data) + self.assertTrue(data, b._stream.getvalue()) diff --git a/neo/Network/core/mixin/__init__.py b/neo/Network/core/mixin/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/neo/Network/core/mixin/serializable.py b/neo/Network/core/mixin/serializable.py new file mode 100644 index 000000000..3f2e058a0 --- /dev/null +++ b/neo/Network/core/mixin/serializable.py @@ -0,0 +1,16 @@ +from abc import abstractmethod, ABC + + +class SerializableMixin(ABC): + + @abstractmethod + def serialize(self, writer) -> None: + pass + + @abstractmethod + def deserialize(self, reader) -> None: + pass + + @abstractmethod + def to_array(self) -> bytearray: + pass diff --git a/neo/Network/core/size.py b/neo/Network/core/size.py new file mode 100644 index 000000000..ecde7b90f --- /dev/null +++ b/neo/Network/core/size.py @@ -0,0 +1,68 @@ +from enum import IntEnum, Enum +from collections import Iterable +from neo.Network.core.mixin.serializable import SerializableMixin +from neo.Network.core.uintbase import UIntBase + +""" +This helper class is intended to help resolve the correct calculation of network serializable objects. +The result of `ctypes.sizeof` is not equivalent to C# or what we expect. See https://github.com/CityOfZion/neo-python/pull/418#issuecomment-389803377 +for more discussion on the topic. +""" + + +class Size(IntEnum): + """ + Explicit bytes of memory consumed + """ + uint8 = 1 + uint16 = 2 + uint32 = 4 + uint64 = 8 + uint160 = 20 + uint256 = 32 + + +def GetVarSize(value): + # public static int GetVarSize(this string value) + if isinstance(value, str): + value_size = len(value.encode('utf-8')) + return GetVarSize(value_size) + value_size + + # internal static int GetVarSize(int value) + elif isinstance(value, int): + if (value < 0xFD): + return Size.uint8 + elif (value <= 0xFFFF): + return Size.uint8 + Size.uint16 + else: + return Size.uint8 + Size.uint32 + + # internal static int GetVarSize(this T[] value) + elif isinstance(value, Iterable): + value_length = len(value) + value_size = 0 + + if value_length > 0: + if isinstance(value[0], SerializableMixin): + if isinstance(value[0], UIntBase): + # because the Size() method in UIntBase is implemented as a property + value_size = sum(map(lambda t: t.Size, value)) + else: + value_size = sum(map(lambda t: t.Size(), value)) + + elif isinstance(value[0], Enum): + # Note: currently all Enum's in neo core (C#) are of type Byte. 
Only porting that part of the code + value_size = value_length * Size.uint8 + elif isinstance(value, (bytes, bytearray)): + # experimental replacement for: value_size = value.Length * Marshal.SizeOf(); + # because I don't think we have a reliable 'SizeOf' in python + value_size = value_length * Size.uint8 + else: + raise TypeError( + "Can not accurately determine size of objects that do not inherit from 'SerializableMixin', 'Enum' or 'bytes'. Found type: {}".format( + type(value[0]))) + + else: + raise ValueError("[NOT SUPPORTED] Unexpected value type {} for GetVarSize()".format(type(value))) + + return GetVarSize(value_length) + value_size diff --git a/neo/Network/core/tests/__init__.py b/neo/Network/core/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/neo/Network/core/tests/test_uint_base.py b/neo/Network/core/tests/test_uint_base.py new file mode 100644 index 000000000..0ef7e2260 --- /dev/null +++ b/neo/Network/core/tests/test_uint_base.py @@ -0,0 +1,125 @@ +from unittest import TestCase +from neo.Network.core.uintbase import UIntBase + + +class UIntBaseTest(TestCase): + def test_create_with_empty_data(self): + x = UIntBase(num_bytes=2) + self.assertEqual(len(x._data), 2) + self.assertEqual(x._data, b'\x00\x00') + + def test_valid_data(self): + x = UIntBase(num_bytes=2, data=b'aabb') + # test for proper conversion to raw bytes + self.assertEqual(len(x._data), 2) + self.assertNotEqual(len(x._data), 4) + + x = UIntBase(num_bytes=3, data=bytearray.fromhex('aabbcc')) + self.assertEqual(len(x._data), 3) + self.assertNotEqual(len(x._data), 6) + + def test_valid_rawbytes_data(self): + x = UIntBase(num_bytes=2, data=b'\xaa\xbb') + self.assertEqual(len(x._data), 2) + self.assertNotEqual(len(x._data), 4) + + def test_invalid_data_type(self): + with self.assertRaises(TypeError) as context: + x = UIntBase(num_bytes=2, data='abc') + self.assertTrue("Invalid data type" in str(context.exception)) + + def test_raw_data_that_can_be_decoded(self): + """ + some raw data can be decoded e.g. 
bytearray.fromhex('1122') but shouldn't be + """ + tricky_raw_data = bytes.fromhex('1122') + x = UIntBase(num_bytes=2, data=tricky_raw_data) + self.assertEqual(x._data, tricky_raw_data) + + def test_data_length_mistmatch(self): + with self.assertRaises(ValueError) as context: + x = UIntBase(num_bytes=2, data=b'aa') # 2 != 1 + self.assertTrue("Invalid UInt: data length" in str(context.exception)) + + def test_size_property(self): + x = UIntBase(num_bytes=2, data=b'\xaa\xbb') + self.assertEqual(x.size, 2) + + def test_hash_code(self): + x = UIntBase(num_bytes=4, data=bytearray.fromhex('DEADBEEF')) + self.assertEqual(x.get_hash_code(), 4022250974) + x = UIntBase(num_bytes=2, data=bytearray.fromhex('1122')) + self.assertEqual(x.get_hash_code(), 8721) + + def test_serialize(self): + pass + + def test_deserialize(self): + pass + + def test_to_array(self): + x = UIntBase(num_bytes=2, data=bytearray.fromhex('1122')) + expected = b'\x11\x22' + self.assertEqual(expected, x.to_array()) + + def test_to_string(self): + x = UIntBase(num_bytes=2, data=bytearray.fromhex('1122')) + self.assertEqual('2211', x.to_string()) + self.assertEqual('2211', str(x)) + self.assertNotEqual('1122', x.to_string()) + self.assertNotEqual('1122', str(x)) + + def test_equal(self): + x = UIntBase(num_bytes=2, data=bytearray.fromhex('1122')) + y = UIntBase(num_bytes=2, data=bytearray.fromhex('1122')) + z = UIntBase(num_bytes=2, data=bytearray.fromhex('2211')) + + self.assertFalse(x is None) + self.assertFalse(x == int(1122)) + self.assertTrue(x == x) + self.assertTrue(x == y) + self.assertTrue(x != z) + + def test_hash(self): + x = UIntBase(num_bytes=2, data=bytearray.fromhex('1122')) + y = UIntBase(num_bytes=2, data=bytearray.fromhex('1122')) + z = UIntBase(num_bytes=2, data=bytearray.fromhex('2211')) + self.assertEqual(hash(x), hash(y)) + self.assertNotEqual(hash(x), hash(z)) + + def test_compare_to(self): + x = UIntBase(num_bytes=2, data=bytearray.fromhex('1122')) + y = UIntBase(num_bytes=3, data=bytearray.fromhex('112233')) + z = UIntBase(num_bytes=2, data=bytearray.fromhex('1133')) + xx = UIntBase(num_bytes=2, data=bytearray.fromhex('1122')) + + # test invalid type + with self.assertRaises(TypeError) as context: + x._compare_to(None) + + expected = "Cannot compare UIntBase to type NoneType" + self.assertEqual(expected, str(context.exception)) + + # test invalid length + with self.assertRaises(ValueError) as context: + x._compare_to(y) + + expected = "Cannot compare UIntBase with length 2 to UIntBase with length 3" + self.assertEqual(expected, str(context.exception)) + + # test data difference ('22' < '33') + self.assertEqual(-1, x._compare_to(z)) + # test data difference ('33' > '22') + self.assertEqual(1, z._compare_to(x)) + # test data equal + self.assertEqual(0, x._compare_to(xx)) + + def test_rich_comparison_methods(self): + x = UIntBase(num_bytes=2, data=bytearray.fromhex('1122')) + z = UIntBase(num_bytes=2, data=bytearray.fromhex('1133')) + xx = UIntBase(num_bytes=2, data=bytearray.fromhex('1122')) + + self.assertTrue(x < z) + self.assertTrue(z > x) + self.assertTrue(x <= xx) + self.assertTrue(x >= xx) diff --git a/neo/Network/core/uint160.py b/neo/Network/core/uint160.py new file mode 100644 index 000000000..0d98d2b35 --- /dev/null +++ b/neo/Network/core/uint160.py @@ -0,0 +1,20 @@ +from neo.Network.core.uintbase import UIntBase + + +class UInt160(UIntBase): + def __init__(self, data=None): + super(UInt160, self).__init__(num_bytes=20, data=data) + + @staticmethod + def from_string(value): + if value[0:2] == '0x': 
+ value = value[2:] + if not len(value) == 40: + raise ValueError(f"Invalid UInt160 Format: {len(value)} chars != 40 chars") + reversed_data = bytearray.fromhex(value) + reversed_data.reverse() + return UInt160(data=reversed_data) + + @classmethod + def zero(cls): + return cls(data=bytearray(20)) diff --git a/neo/Network/core/uint256.py b/neo/Network/core/uint256.py new file mode 100644 index 000000000..615f83660 --- /dev/null +++ b/neo/Network/core/uint256.py @@ -0,0 +1,20 @@ +from neo.Network.core.uintbase import UIntBase + + +class UInt256(UIntBase): + def __init__(self, data=None): + super(UInt256, self).__init__(num_bytes=32, data=data) + + @staticmethod + def from_string(value): + if value[0:2] == '0x': + value = value[2:] + if not len(value) == 64: + raise ValueError(f"Invalid UInt256 Format: {len(value)} chars != 64 chars") + reversed_data = bytearray.fromhex(value) + reversed_data.reverse() + return UInt256(data=reversed_data) + + @classmethod + def zero(cls): + return cls(data=bytearray(32)) diff --git a/neo/Network/core/uintbase.py b/neo/Network/core/uintbase.py new file mode 100644 index 000000000..bdc4008d5 --- /dev/null +++ b/neo/Network/core/uintbase.py @@ -0,0 +1,124 @@ +import binascii +from neo.Network.core.mixin import serializable +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + from neo.Network.core import BinaryReader + from neo.Network.core.io.binary_writer import BinaryWriter + + +class UIntBase(serializable.SerializableMixin): + _data = bytearray() + _hash: int = 0 + + def __init__(self, num_bytes: int, data: Union[bytes, bytearray] = None) -> None: + super(UIntBase, self).__init__() + + if data is None: + self._data = bytearray(num_bytes) + + else: + if isinstance(data, bytes): + # make sure it's mutable for string representation + self._data = bytearray(data) + elif isinstance(data, bytearray): + self._data = data + else: + raise TypeError("Invalid data type {}. Expecting bytes or bytearray".format(type(data))) + + # now make sure we're working with raw bytes + try: + self._data = bytearray(binascii.unhexlify(self._data.decode())) + except UnicodeDecodeError: + # decode() fails most of the time if data is already in raw bytes. In that case there is nothing to be done. + pass + except binascii.Error: + # however in some cases like bytes.fromhex('1122') decoding passes, + # but binascii fails because it was actually already in rawbytes. Still nothing to be done. + pass + + if len(self._data) != num_bytes: + raise ValueError("Invalid UInt: data length {} != specified num_bytes {}".format(len(self._data), num_bytes)) + + self._hash = self.get_hash_code() + + @property + def size(self) -> int: + """ Count of data bytes. """ + return len(self._data) + + def get_hash_code(self) -> int: + """ Get a uint32 identifier. """ + slice_length = 4 if len(self._data) >= 4 else len(self._data) + return int.from_bytes(self._data[:slice_length], 'little') + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. """ + writer.write_bytes(self._data) + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize object. """ + self._data = reader.read_bytes(self.size) + + def to_array(self) -> bytearray: + """ get the raw data. """ + return self._data + + def to_string(self) -> str: + """ Convert the data to a human readable format (data is in reverse order). 
""" + db = bytearray(self._data) + db.reverse() + return db.hex() + + def __eq__(self, other) -> bool: + if other is None: + return False + + if not isinstance(other, UIntBase): + return False + + if other is self: + return True + + if self._data == other._data: + return True + + return False + + def __hash__(self): + return self._hash + + def __str__(self): + return self.to_string() + + def _compare_to(self, other) -> int: + if not isinstance(other, UIntBase): + raise TypeError('Cannot compare %s to type %s' % (type(self).__name__, type(other).__name__)) + + x = self.to_array() + y = other.to_array() + + if len(x) != len(y): + raise ValueError('Cannot compare %s with length %s to %s with length %s' % (type(self).__name__, len(x), type(other).__name__, len(y))) + + length = len(x) + + for i in range(length - 1, 0, -1): + if x[i] > y[i]: + return 1 + if x[i] < y[i]: + return -1 + + return 0 + + def __lt__(self, other): + return self._compare_to(other) < 0 + + def __gt__(self, other): + return self._compare_to(other) > 0 + + def __le__(self, other): + return self._compare_to(other) <= 0 + + def __ge__(self, other): + return self._compare_to(other) >= 0 diff --git a/neo/Network/flightinfo.py b/neo/Network/flightinfo.py new file mode 100644 index 000000000..978d4cf05 --- /dev/null +++ b/neo/Network/flightinfo.py @@ -0,0 +1,11 @@ +from datetime import datetime + + +class FlightInfo: + def __init__(self, node_id, height): + self.node_id: int = node_id + self.height: int = height + self.start_time: int = datetime.utcnow().timestamp() + + def reset_start_time(self): + self.start_time = datetime.utcnow().timestamp() diff --git a/neo/Network/ipfilter.py b/neo/Network/ipfilter.py new file mode 100644 index 000000000..3ce2265d5 --- /dev/null +++ b/neo/Network/ipfilter.py @@ -0,0 +1,88 @@ +from ipaddress import IPv4Network +from contextlib import suppress + +""" + A class for filtering IPs. 
+ + * The whitelist has precedence over the blacklist settings + * Host masks can be applied + * When using host masks do not set host bits (leave them to 0) or an exception will occur + + Common scenario examples: + + 1) Accept only specific trusted IPs + { + 'blacklist': [ + '0.0.0.0/0' + ], + 'whitelist': [ + '10.10.10.10', + '15.15.15.15' + ] + } +2) Accept only a range of trusted IPs + # accepts any IP in the range of 10.10.10.0 - 10.10.10.255 + { + 'blacklist': [ + '0.0.0.0/0' + ], + 'whitelist': [ + '10.10.10.0/24', + ] + } + +3 ) Accept everybody except specific IPs + # can be used for banning bad actors + { + 'blacklist': [ + '12.12.12.12', + '13.13.13.13' + ], + 'whitelist': [ + ] + } + + +""" + + +class IPFilter(): + config = {'blacklist': [], 'whitelist': []} + + def is_allowed(self, host_address) -> bool: + address = IPv4Network(host_address) + + is_allowed = True + + for ip in self.config['blacklist']: + disallowed = IPv4Network(ip) + if disallowed.overlaps(address): + is_allowed = False + break + else: + return is_allowed + + # can override blacklist + for ip in self.config['whitelist']: + allowed = IPv4Network(ip) + if allowed.overlaps(address): + is_allowed = True + + return is_allowed + + def blacklist_add(self, address) -> None: + self.config['blacklist'].append(address) + + def blacklist_remove(self, address) -> None: + with suppress(ValueError): + self.config['blacklist'].remove(address) + + def whitelist_add(self, address) -> None: + self.config['whitelist'].append(address) + + def whitelist_remove(self, address) -> None: + with suppress(ValueError): + self.config['whitelist'].remove(address) + + +ipfilter = IPFilter() diff --git a/neo/Network/ledger.py b/neo/Network/ledger.py new file mode 100644 index 000000000..47094d9e9 --- /dev/null +++ b/neo/Network/ledger.py @@ -0,0 +1,85 @@ +import binascii +import asyncio +from typing import TYPE_CHECKING, List +from neo.Core.Blockchain import Blockchain +from neo.Core.Block import Block +from neo.IO.Helper import Helper as IOHelper +from neo.Network.core.uint256 import UInt256 +from neo.logging import log_manager +import traceback + +logger = log_manager.getLogger('db') + +if TYPE_CHECKING: + from neo.Network.core import Header + from neo.Implementations.Blockchains.LevelDB.LevelDBBlockchain import LevelDBBlockchain + + +class Ledger: + def __init__(self, controller=None): + self.controller = controller + self.ledger = Blockchain.Default() # type: LevelDBBlockchain + + async def cur_header_height(self) -> int: + return self.ledger.HeaderHeight + # return await self.controller.get_current_header_height() + + async def cur_block_height(self) -> int: + # return await self.controller.get_current_block_height() + return self.ledger.Height + + async def header_hash_by_height(self, height: int) -> 'UInt256': + # return await self.controller.get_header_hash_by_height(height) + header_hash = self.ledger.GetHeaderHash(height) + if header_hash is None: + data = bytearray(32) + else: + data = bytearray(binascii.unhexlify(header_hash)) + data.reverse() + return UInt256(data=data) + + async def add_headers(self, network_headers: List['Header']) -> int: + """ + + Args: + headers: + + Returns: number of headers added + + """ + headers = [] + count = 0 + for h in network_headers: + header = IOHelper.AsSerializableWithType(h.to_array(), 'neo.Core.Header.Header') + if header is None: + break + else: + headers.append(header) + # just making sure we don't block too long while converting + await asyncio.sleep(0.001) + else: + count = 
self.ledger.AddHeaders(headers) + + return count + + async def add_block(self, raw_block: bytes) -> bool: + # return await self.controller.add_block(block) + block = IOHelper.AsSerializableWithType(raw_block, 'neo.Core.Block.Block') # type: Block + + if block is None: + return False + else: + self.ledger.AddHeader(block.Header) + + success, reason = await self.ledger.TryPersist(block) + if not success: + logger.debug(f"Failed to Persist block. Reason: {reason}") + return False + + try: + self.ledger.OnPersistCompleted(block) + except Exception as e: + traceback.print_exc() + logger.debug(f"Failed to broadcast OnPersistCompleted event, reason: {e}") + + return True diff --git a/neo/Network/mempool.py b/neo/Network/mempool.py new file mode 100644 index 000000000..62b2dd442 --- /dev/null +++ b/neo/Network/mempool.py @@ -0,0 +1,42 @@ +from contextlib import suppress + +from neo.Core.Block import Block as OrigBlock +from neo.Core.Blockchain import Blockchain as BC +from neo.Network.common import msgrouter +from neo.Network.common.singleton import Singleton +from neo.logging import log_manager + +logger = log_manager.getLogger('network') + + +class MemPool(Singleton): + def init(self): + self.pool = dict() + msgrouter.on_block_persisted += self.update_pool_for_block_persist + + def add_transaction(self, tx) -> bool: + if BC.Default() is None: + return False + + if tx.Hash.ToString() in self.pool.keys(): + return False + + if BC.Default().ContainsTransaction(tx.Hash): + return False + + if not tx.Verify(self.pool.values()): + logger.error("Verifying tx result... failed") + return False + + self.pool[tx.Hash] = tx + + return True + + def update_pool_for_block_persist(self, orig_block: OrigBlock) -> None: + for tx in orig_block.Transactions: + with suppress(KeyError): + self.pool.pop(tx.Hash) + logger.debug(f"Found {tx.Hash} in last persisted block. Removing from mempool") + + def reset(self) -> None: + self.pool = dict() diff --git a/neo/Network/message.py b/neo/Network/message.py new file mode 100644 index 000000000..2748f5ebd --- /dev/null +++ b/neo/Network/message.py @@ -0,0 +1,104 @@ +import hashlib +from typing import Union +from typing import TYPE_CHECKING, Optional +from neo.Network.payloads.base import BasePayload +from neo.Network.core.mixin.serializable import SerializableMixin +from neo.Network.core.size import Size as s +from neo.Network.core.io.binary_writer import BinaryWriter + +bytes_or_payload = Union[bytes, BasePayload] + +if TYPE_CHECKING: + from neo.Network.core import BinaryReader + + +class ChecksumException(Exception): + pass + + +class Message(SerializableMixin): + _payload_max_size = int.from_bytes(bytes.fromhex('02000000'), 'big') + _magic = None + + def __init__(self, magic: Optional[int] = None, command: Optional[str] = None, payload: Optional[bytes_or_payload] = None) -> None: + """ + Create an instance. + + Args: + command: max 12 bytes, utf-8 encoded payload command + payload: raw bytes of the payload. + """ + self.command = command + if magic: + self.magic = magic + else: + # otherwise set to class variable. + self.magic = self._magic + + self.payload_length = 0 + if payload is None: + self.payload = bytearray() + else: + if isinstance(payload, BasePayload): + self.payload = payload.to_array() + else: + self.payload = payload + self.payload_length = len(self.payload) + + self.checksum = None + + def __len__(self) -> int: + return self.size() + + def size(self) -> int: + """ Get the total size in bytes of the object. 
""" + return s.uint32 + 12 + s.uint32 + s.uint32 + len(self.payload) + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. """ + writer.write_uint32(self.magic) + writer.write_fixed_string(self.command, 12) + writer.write_uint32(len(self.payload)) + writer.write_uint32(self.get_checksum()) + writer.write_bytes(self.payload) + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize full object. """ + self.magic = reader.read_uint32() + self.command = reader.read_fixed_string(12) + self.payload_length = reader.read_uint32() + + if self.payload_length > self._payload_max_size: + raise ValueError("Specified payload length exceeds maximum payload length") + + self.checksum = reader.read_uint32() + self.payload = reader.read_bytes(self.payload_length) + + checksum = self.get_checksum() + + if checksum != self.checksum: + raise ChecksumException("checksum mismatch") + + def get_checksum(self, value: Optional[Union[bytes, bytearray]] = None) -> int: + """ + Get the double SHA256 hash of the value. + + Args: + value (raw bytes): a payload + + Returns: + int: checksum + """ + if not value: + value = self.payload + + uint32 = hashlib.sha256(hashlib.sha256(value).digest()).digest() + x = uint32[:4] + return int.from_bytes(x, 'little') + + def to_array(self) -> bytearray: + writer = BinaryWriter(stream=bytearray()) + self.serialize(writer) + data = bytearray(writer._stream.getvalue()) + writer.cleanup() + return data diff --git a/neo/Network/node.py b/neo/Network/node.py new file mode 100644 index 000000000..d58242ba1 --- /dev/null +++ b/neo/Network/node.py @@ -0,0 +1,291 @@ +from neo.Network.message import Message +from neo.Network.payloads.version import VersionPayload +from neo.Network.payloads.getblocks import GetBlocksPayload +from neo.Network.payloads.addr import AddrPayload +from neo.Network.payloads.networkaddress import NetworkAddressWithTime +from neo.Network.payloads.inventory import InventoryPayload, InventoryType +from neo.Network.payloads.block import Block +from neo.Network.payloads.headers import HeadersPayload +from neo.Network.payloads.ping import PingPayload +from neo.Network.core.uint256 import UInt256 +from neo.Network.core.header import Header +from neo.Network.ipfilter import ipfilter +from neo.Blockchain import GetBlockchain +from datetime import datetime +from typing import Optional, List, TYPE_CHECKING +import asyncio +from contextlib import suppress +from neo.Network.common import msgrouter, encode_base62 +from neo.Network.nodeweight import NodeWeight +from neo.logging import log_manager +from neo.Settings import settings +import binascii + +logger = log_manager.getLogger('network') + +if TYPE_CHECKING: + from neo.Network.nodemanager import NodeManager + from neo.Network.protocol import NeoProtocol + + +class NeoNode: + def __init__(self, protocol: 'NeoProtocol', nodemanager: 'NodeManager', quality_check=False): + self.protocol = protocol + self.nodemanager = nodemanager + self.quality_check = quality_check + + self.address = None + self.nodeid = id(self) + self.nodeid_human = encode_base62(self.nodeid) + self.version = None + self.tasks = [] + self.nodeweight = NodeWeight(self.nodeid) + self.best_height = 0 # track the block height of node + + self._inv_hash_for_height = None # temp variable to track which hash we used for determining the nodes best height + self.main_task = None # type: asyncio.Task + self.disconnecting = False + + # connection setup and control functions + async def connection_made(self, transport) -> None: + 
addr_tuple = self.protocol._stream_writer.get_extra_info('peername') + self.address = f"{addr_tuple[0]}:{addr_tuple[1]}" + + if not ipfilter.is_allowed(addr_tuple[0]): + await self.disconnect() + return + + task = asyncio.create_task(self.do_handshake()) + self.tasks.append(task) # storing the task in case the connection is lost before it finishes the task, this allows us to cancel the task + await task + self.tasks.remove(task) + + async def do_handshake(self) -> None: + send_version = Message(command='version', payload=VersionPayload(port=settings.NODE_PORT, userAgent=settings.VERSION_NAME)) + await self.send_message(send_version) + + m = await self.read_message(timeout=3) + if not m or m.command != 'version': + await self.disconnect() + return + + if not self.validate_version(m.payload): + await self.disconnect() + return + + m_verack = Message(command='verack') + await self.send_message(m_verack) + + m = await self.read_message(timeout=3) + if not m or m.command != 'verack': + await self.disconnect() + return + + if self.quality_check: + self.nodemanager.quality_check_result(self.address, healthy=True) + await self.disconnect() + return + else: + logger.debug(f"Connected to {self.version.user_agent} @ {self.address}: {self.version.start_height}") + self.nodemanager.add_connected_node(self) + self.main_task = asyncio.create_task(self.run()) + # when we break out of the run loop, we should make sure we disconnect + self.main_task.add_done_callback(lambda _: asyncio.create_task(self.disconnect())) + + async def disconnect(self) -> None: + self.disconnecting = True + + for t in self.tasks: + t.cancel() + with suppress(asyncio.CancelledError): + await t + self.nodemanager.remove_connected_node(self) + self.protocol.disconnect() + + async def connection_lost(self, exc) -> None: + logger.debug(f"{datetime.now()} Connection lost {self.address} excL {exc}") + + await self.disconnect() + + if self.quality_check: + self.nodemanager.quality_check_result(self.address, healthy=False) + + def validate_version(self, data) -> bool: + try: + self.version = VersionPayload.deserialize_from_bytes(data) + except ValueError: + logger.debug("failed to deserialize Version") + return False + + if self.version.nonce == self.nodeid: + logger.debug("Client is self") + return False + + # update nodes height indicator + self.best_height = self.version.start_height + + return True + + async def run(self) -> None: + """ + Main loop + """ + logger.debug("Waiting for a message") + while not self.disconnecting: + # we want to always listen for an incoming message + + message = await self.read_message(timeout=0) + if not message: + continue + + if message.command == 'addr': + addr_payload = AddrPayload.deserialize_from_bytes(message.payload) + for a in addr_payload.addresses: + msgrouter.on_addr(f"{a.address}:{a.port}") + elif message.command == 'getaddr': + await self.send_address_list() + elif message.command == 'inv': + inv = InventoryPayload.deserialize_from_bytes(message.payload) + if not inv: + continue + + if inv.type == InventoryType.block: + # neo-cli broadcasts INV messages on a regular interval. 
We can use those as trigger to request their latest block height + # supported from 2.10.0.1 onwards + if len(inv.hashes) > 0: + m = Message(command='ping', payload=PingPayload(GetBlockchain().Height)) + await self.send_message(m) + # self._inv_hash_for_height = inv.hashes[-1] + # await self.get_data(inv.type, inv.hashes) + elif inv.type == InventoryType.consensus: + pass + elif inv.type == InventoryType.tx: + pass + elif message.command == 'block': + block = Block.deserialize_from_bytes(message.payload) + if block: + if self._inv_hash_for_height == block.hash and block.index > self.best_height: + logger.debug(f"Updating node {self.nodeid_human} height from {self.best_height} to {block.index}") + self.best_height = block.index + self._inv_hash_for_height = None + + await msgrouter.on_block(self.nodeid, block, message.payload) + elif message.command == 'headers': + header_payload = HeadersPayload.deserialize_from_bytes(message.payload) + + if header_payload and len(header_payload.headers) > 0: + await msgrouter.on_headers(self.nodeid, header_payload.headers) + del header_payload + + elif message.command == 'pong': + payload = PingPayload.deserialize_from_bytes(message.payload) + if payload: + logger.debug(f"Updating node {self.nodeid_human} height from {self.best_height} to {payload.current_height}") + self.best_height = payload.current_height + self._inv_hash_for_height = None + elif message.command == 'getdata': + inv = InventoryPayload.deserialize_from_bytes(message.payload) + if not inv: + continue + + for h in inv.hashes: + item = self.nodemanager.relay_cache.try_get(h) + if item is None: + # for the time being we only support data retrieval for our own relays + continue + if inv.type == InventoryType.tx: + raw_payload = binascii.unhexlify(item.ToArray()) + m = Message(command='tx', payload=raw_payload) # this is still an old code base InventoryMixin type + await self.send_message(m) + else: + if message.command not in ['consensus', 'getheaders']: + logger.debug(f"Message with command: {message.command}") + + # raw network commands + async def get_address_list(self) -> None: + """ Send a request for receiving known addresses""" + m = Message(command='getaddr') + await self.send_message(m) + + async def send_address_list(self) -> None: + """ Send our known addresses """ + known_addresses = [] + for node in self.nodemanager.nodes: + host, port = node.address.split(':') + if host and port: + known_addresses.append(NetworkAddressWithTime(address=host, port=int(port))) + if len(known_addresses) > 0: + m = Message(command='address', payload=AddrPayload(addresses=known_addresses)) + await self.send_message(m) + + async def get_headers(self, hash_start: UInt256, hash_stop: Optional[UInt256] = None) -> None: + """ Send a request for headers from `hash_start` + 1 to `hash_stop` + + Not specifying a `hash_stop` results in requesting at most 2000 headers. + """ + m = Message(command='getheaders', payload=GetBlocksPayload(hash_start, hash_stop)) + await self.send_message(m) + + async def send_headers(self, headers: List[Header]) -> None: + """ Send a list of Header objects. + + This is usually done as a response to a 'getheaders' request. 
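+        At most 2000 headers are sent in a single message; a longer list is
+        truncated before sending.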
+ """ + if len(headers) > 2000: + headers = headers[:2000] + + m = Message(command='headers', payload=HeadersPayload(headers)) + await self.send_message(m) + + async def get_blocks(self, hash_start: UInt256, hash_stop: Optional[UInt256] = None) -> None: + """ Send a request for blocks from `hash_start` + 1 to `hash_stop` + + Not specifying a `hash_stop` results in requesting at most 500 blocks. + """ + m = Message(command='getblocks', payload=GetBlocksPayload(hash_start, hash_stop)) + await self.send_message(m) + + async def get_data(self, type: InventoryType, hashes: List[UInt256]) -> None: + """ Send a request for receiving the specified inventory data.""" + if len(hashes) < 1: + return + + m = Message(command='getdata', payload=InventoryPayload(type, hashes)) + await self.send_message(m) + + async def relay(self, inventory) -> bool: + """ + Try to relay the inventory to the network + + Args: + inventory: should be of type Block, Transaction or ConsensusPayload (see: InventoryType) + + Returns: False if inventory is already in the mempool, or if relaying to nodes failed (e.g. because we have no nodes connected) + + """ + # TODO: this is based on the current/old neo-python Block, Transaction and ConsensusPlayload classes + # meaning attribute naming will change (no longer camelCase) once we move to python naming convention + # for now we need to convert them to our new types or calls will fail + new_inventorytype = InventoryType(inventory.InventoryType) + new_hash = UInt256(data=inventory.Hash.ToArray()) + inv = InventoryPayload(type=new_inventorytype, hashes=[new_hash]) + m = Message(command='inv', payload=inv) + await self.send_message(m) + + return True + + # utility functions + async def send_message(self, message: Message) -> None: + await self.protocol.send_message(message) + + async def read_message(self, timeout: int = 30) -> Optional[Message]: + return await self.protocol.read_message(timeout) + + def __eq__(self, other): + if type(other) is type(self): + return self.address == other.address and self.nodeid == other.nodeid + else: + return False + + def __repr__(self): + return f"<{self.__class__.__name__} at {hex(id(self))}> {self.nodeid_human}" diff --git a/neo/Network/nodemanager.py b/neo/Network/nodemanager.py new file mode 100644 index 000000000..0593b9a71 --- /dev/null +++ b/neo/Network/nodemanager.py @@ -0,0 +1,381 @@ +import asyncio +import socket +import traceback +import errno +from contextlib import suppress +from datetime import datetime +from functools import partial +from socket import AF_INET as IP4_FAMILY +from typing import Optional, List + +from neo.Core.TX.Transaction import Transaction as OrigTransaction +from neo.Network.common import msgrouter, wait_for +from neo.Network.common.singleton import Singleton +from neo.Network import utils as networkutils +from neo.Network.mempool import MemPool +from neo.Network.node import NeoNode +from neo.Network.protocol import NeoProtocol +from neo.Network.relaycache import RelayCache +from neo.Network.requestinfo import RequestInfo +from neo.Settings import settings +from neo.logging import log_manager + + +logger = log_manager.getLogger('network') +# log_manager.config_stdio([('network', 10)]) + + +class NodeManager(Singleton): + PEER_QUERY_INTERVAL = 15 + NODE_POOL_CHECK_INTERVAL = 10 # 2.5 * PEER_QUERY_INTERVAL # this allows for enough time to get new addresses + + ONE_MINUTE = 60 + + MAX_ERROR_COUNT = 5 # maximum number of times adding a block or header may fail before we disconnect it + MAX_TIMEOUT_COUNT = 15 # 
maximum count the node responds slower than our threshold + + MAX_NODE_POOL_ERROR = 2 + MAX_NODE_POOL_ERROR_COUNT = 0 + + # we override init instead of __init__ due to the Singleton (read class documentation) + def init(self): + self.loop = asyncio.get_event_loop() + self.max_clients = settings.CONNECTED_PEER_MAX + self.min_clients = settings.CONNECTED_PEER_MIN + self.id = id(self) + self.mempool = MemPool() + + # a list of nodes that we're actively using to request data from + self.nodes = [] # type: List[NeoNode] + # a list of host:port addresses that have a task pending to to connect to, but are not fully processed + self.queued_addresses = [] + # a list of addresses which we know are bad. Reasons include; failed to connect, went offline, poor performance + self.bad_addresses = [] + # a list of addresses that we've tested to be alive but that we're currently not connected to because we've + # reached our `max_clients` setting. We use these addresses to quickly replace a bad node + self.known_addresses = [] + + self.connection_queue = asyncio.Queue() + + # a list for gathering tasks such that we can manually determine the order of shutdown + self.tasks = [] + self.shutting_down = False + + self.relay_cache = RelayCache() + + msgrouter.on_addr += self.on_addr_received + + self.running = False + + async def start(self): + host = 'localhost' + port = 8888 # settings.NODE_PORT + proto = partial(NeoProtocol, nodemanager=self) + + try: + await self.loop.create_server(proto, host, port) + except OSError as e: + if e.errno == errno.EADDRINUSE: + print(f"Node address {host}:{port} already in use ") + raise SystemExit + else: + raise e + print(f"[{datetime.now()}] Running P2P network on {host} {port}") + + for seed in settings.SEED_LIST: + host, port = seed.split(':') + if not networkutils.is_ip_address(host): + try: + # TODO: find a way to make socket.gethostbyname non-blocking as it can take very long to look up + # using loop.run_in_executor was unsuccessful. 
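+                    # a possibly non-blocking alternative (untested suggestion):
+                    # the event loop's own `getaddrinfo()` coroutine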
+ host = networkutils.hostname_to_ip(host) + except socket.gaierror as e: + logger.debug(f"Skipping {host}, address could not be resolved: {e}") + continue + + self.known_addresses.append(f"{host}:{port}") + + self.tasks.append(asyncio.create_task(self.handle_connection_queue())) + self.tasks.append(asyncio.create_task(self.query_peer_info())) + self.tasks.append(asyncio.create_task(self.ensure_full_node_pool())) + + self.running = True + + async def handle_connection_queue(self) -> None: + while not self.shutting_down: + addr, quality_check = await self.connection_queue.get() + task = asyncio.create_task(self._connect_to_node(addr, quality_check)) + self.tasks.append(task) + task.add_done_callback(lambda fut: self.tasks.remove(fut)) + + async def query_peer_info(self) -> None: + while not self.shutting_down: + logger.debug(f"Connected node count {len(self.nodes)}") + for node in self.nodes: + task = asyncio.create_task(node.get_address_list()) + self.tasks.append(task) + task.add_done_callback(lambda fut: self.tasks.remove(fut)) + await asyncio.sleep(self.PEER_QUERY_INTERVAL) + + async def ensure_full_node_pool(self) -> None: + while not self.shutting_down: + self.check_open_spots_and_queue_nodes() + await asyncio.sleep(self.NODE_POOL_CHECK_INTERVAL) + + def check_open_spots_and_queue_nodes(self) -> None: + open_spots = self.max_clients - (len(self.nodes) + len(self.queued_addresses)) + + if open_spots > 0: + logger.debug(f"Found {open_spots} open pool spots, trying to add nodes...") + for _ in range(open_spots): + try: + addr = self.known_addresses.pop(0) + self.queue_for_connection(addr) + except IndexError: + # oh no, we've exhausted our good addresses list + if len(self.nodes) < self.min_clients: + if self.MAX_NODE_POOL_ERROR_COUNT != self.MAX_NODE_POOL_ERROR: + # give our `query_peer_info` loop a chance to collect new addresses + self.MAX_NODE_POOL_ERROR_COUNT += 1 + break + else: + # we have no other option then to retry any address we know + logger.debug("Recycling old addresses") + self.known_addresses = self.bad_addresses + self.bad_addresses = [] + self.MAX_NODE_POOL_ERROR_COUNT = 0 + + def add_connected_node(self, node: NeoNode) -> None: + if node not in self.nodes and not self.shutting_down: + self.nodes.append(node) + + if node.address in self.queued_addresses: + self.queued_addresses.remove(node.address) + + def remove_connected_node(self, node: NeoNode) -> None: + with suppress(ValueError): + self.queued_addresses.remove(node.address) + + with suppress(ValueError): + self.nodes.remove(node) + + def get_next_node(self, height: int) -> Optional[NeoNode]: + """ + + Args: + height: the block height for which we're requesting data. 
Used to filter nodes that have this data + + Returns: + + """ + if len(self.nodes) == 0: + return None + + weights = list(map(lambda n: n.nodeweight, self.nodes)) + # highest weight is taken first + weights.sort(reverse=True) + + for weight in weights: + node = self.get_node_by_nodeid(weight.id) + if node and height <= node.best_height: + return node + else: + # we could not find a node with the height we're looking for + return None + + async def replace_node(self, node) -> None: + if node.address not in self.bad_addresses: + self.bad_addresses.append(node.address) + + asyncio.create_task(node.disconnect()) + + with suppress(IndexError): + addr = self.known_addresses.pop(0) + self.queue_for_connection(addr) + + async def add_node_error_count(self, nodeid: int) -> None: + node = self.get_node_by_nodeid(nodeid) + if node: + node.nodeweight.error_response_count += 1 + + if node.nodeweight.error_response_count > self.MAX_ERROR_COUNT: + logger.debug(f"Disconnecting node {node.nodeid} Reason: max error count threshold exceeded") + await self.replace_node(node) + + async def add_node_timeout_count(self, nodeid: int) -> None: + node = self.get_node_by_nodeid(nodeid) + if node: + node.nodeweight.timeout_count += 1 + + if node.nodeweight.timeout_count > self.MAX_TIMEOUT_COUNT: + logger.debug(f"Disconnecting node {node.nodeid} Reason: max timeout count threshold exceeded") + await self.replace_node(node) + + def get_node_with_min_failed_time(self, ri: RequestInfo) -> Optional[NeoNode]: + # Find the node with the least failures for the item in RequestInfo + + least_failed_times = 999 + least_failed_node = None + tried_nodes = [] + + while True: + node = self.get_next_node(ri.height) + if not node: + return None + + failed_times = ri.failed_nodes.get(node.nodeid, 0) + if failed_times == 0: + # return the node we haven't tried this request on before + return node + + if node.nodeid in tried_nodes: + # we've exhausted the node list and should just go with our best available option + return least_failed_node + + tried_nodes.append(node.nodeid) + if failed_times < least_failed_times: + least_failed_times = failed_times + least_failed_node = node + + def get_node_by_nodeid(self, nodeid: int) -> Optional[NeoNode]: + for n in self.nodes: + if n.nodeid == nodeid: + return n + else: + return None + + def connected_addresses(self) -> List[str]: + return list(map(lambda n: n.address, self.nodes)) + + def on_addr_received(self, addr) -> None: + if addr in self.bad_addresses or addr in self.queued_addresses or addr in self.known_addresses: + # we received a duplicate + return + + if addr not in self.connected_addresses(): + self.known_addresses.append(addr) + # it's a new address, see if we can make it part of the current connection pool + if len(self.nodes) + len(self.queued_addresses) < self.max_clients: + self.queue_for_connection(addr) + else: + # current pool is full, but.. 
+ # we can test out the new addresses ahead of time as we might receive dead + # or poor performing addresses from neo-cli nodes + self.queue_for_connection(addr, only_quality_check=True) + + def quality_check_result(self, addr, healthy) -> None: + if addr is None: + logger.debug("WARNING QUALITY CHECK ADDR IS NONE!") + if not healthy: + with suppress(ValueError): + self.known_addresses.remove(addr) + + if addr not in self.bad_addresses: + self.bad_addresses.append(addr) + + def queue_for_connection(self, addr, only_quality_check=False) -> None: + if only_quality_check: + # quality check connections will disconnect after a successful handshake + # they should not count towards the total connected nodes list + logger.debug(f"Adding {addr} to connection queue for quality checking") + task = asyncio.create_task(self.connection_queue.put((addr, only_quality_check))) + self.tasks.append(task) + task.add_done_callback(lambda fut: self.tasks.remove(fut)) + else: + # check if there is space for another node according to our max clients settings + if len(self.nodes) + len(self.queued_addresses) < self.max_clients: + # regular connections should count towards the total connected nodes list + if addr not in self.queued_addresses and addr not in self.connected_addresses(): + self.queued_addresses.append(addr) + logger.debug(f"Adding {addr} to connection queue") + task = asyncio.create_task(self.connection_queue.put((addr, only_quality_check))) + self.tasks.append(task) + task.add_done_callback(lambda fut: self.tasks.remove(fut)) + + def relay(self, inventory) -> bool: + if type(inventory) is OrigTransaction or issubclass(type(inventory), OrigTransaction): + success = self.mempool.add_transaction(inventory) + if not success: + return False + + # TODO: should we keep the tx in the mempool if relaying failed? There is currently no mechanism that retries sending failed tx's + return wait_for(self.relay_directly(inventory)) + + async def relay_directly(self, inventory) -> bool: + relayed = False + + self.relay_cache.add(inventory) + + for node in self.nodes: + relayed |= await node.relay(inventory) + + return relayed + + def reset_for_test(self) -> None: + self.max_clients = settings.CONNECTED_PEER_MAX + self.min_clients = settings.CONNECTED_PEER_MIN + self.id = id(self) + self.mempool.reset() + self.nodes = [] # type: List[NeoNode] + self.queued_addresses = [] + self.bad_addresses = [] + self.known_addresses = [] + self.connection_queue = asyncio.Queue() + self.relay_cache.reset() + + """ + Internal helpers + """ + + async def _connect_to_node(self, address: str, quality_check=False, timeout=3) -> None: + host, port = address.split(':') + + try: + proto = partial(NeoProtocol, nodemanager=self, quality_check=quality_check) + connect_coro = self.loop.create_connection(proto, host, port, family=IP4_FAMILY) + # print(f"trying to connect to: {host}:{port}") + await asyncio.wait_for(connect_coro, timeout) + return + except asyncio.TimeoutError: + # print(f"{host}:{port} timed out") + pass + except OSError as e: + # print(f"{host}:{port} failed to connect for reason {e}") + pass + except asyncio.CancelledError: + pass + except Exception as e: + print("ohhh, some error we didn't expect happened. 
Please create a Github issue and share the following stacktrace so we can try to resolve it") + print("----------------[start of trace]----------------") + traceback.print_exc() + print("----------------[end of trace]----------------") + + with suppress(ValueError): + self.queued_addresses.remove(address) + + with suppress(ValueError): + self.known_addresses.remove(address) + + self.bad_addresses.append(address) + + async def shutdown(self) -> None: + print("Shutting down node manager...", end='') + self.shutting_down = True + + # shutdown all running tasks for this class + # to prevent requeueing when disconnecting nodes + logger.debug("stopping tasks...") + for t in self.tasks: + t.cancel() + await asyncio.gather(*self.tasks, return_exceptions=True) + + # finally disconnect all existing connections + # we need to create a new list to loop over, because `disconnect` removes items from self.nodes + to_disconnect = list(map(lambda n: n, self.nodes)) + disconnect_tasks = [] + logger.debug("disconnecting nodes...") + for n in to_disconnect: + disconnect_tasks.append(asyncio.create_task(n.disconnect())) + await asyncio.gather(*disconnect_tasks, return_exceptions=True) + + print("DONE") diff --git a/neo/Network/nodeweight.py b/neo/Network/nodeweight.py new file mode 100644 index 000000000..0cd4b0ea8 --- /dev/null +++ b/neo/Network/nodeweight.py @@ -0,0 +1,61 @@ +from datetime import datetime + + +class NodeWeight: + SPEED_RECORD_COUNT = 3 + SPEED_INIT_VALUE = 100 * 1024 ^ 2 # Start with a big speed of 100 MB/s + + REQUEST_TIME_RECORD_COUNT = 3 + + def __init__(self, nodeid): + self.id: int = nodeid + self.speed = [self.SPEED_INIT_VALUE] * self.SPEED_RECORD_COUNT + self.timeout_count = 0 + self.error_response_count = 0 + now = datetime.utcnow().timestamp() * 1000 # milliseconds + self.request_time = [now] * self.REQUEST_TIME_RECORD_COUNT + + def append_new_speed(self, speed) -> None: + # remove oldest + self.speed.pop(-1) + # add new + self.speed.insert(0, speed) + + def append_new_request_time(self) -> None: + self.request_time.pop(-1) + + now = datetime.utcnow().timestamp() * 1000 # milliseconds + self.request_time.insert(0, now) + + def _avg_speed(self) -> float: + return sum(self.speed) / self.SPEED_RECORD_COUNT + + def _avg_request_time(self) -> float: + avg_request_time = 0 + now = datetime.utcnow().timestamp() * 1000 # milliseconds + + for t in self.request_time: + avg_request_time += now - t + + avg_request_time = avg_request_time / self.REQUEST_TIME_RECORD_COUNT + return avg_request_time + + def weight(self): + # nodes with the highest speed and the longest time between querying for data have the highest weight + # and will be accessed first unless their error/timeout count is higher. 
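+        # e.g. (illustrative numbers only) an average speed of 100 and an
+        # average request age of 50 give a weight of 150; one recorded error
+        # halves that to 75 and an additional timeout halves it again to 37.5.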
This distributes load across nodes + weight = self._avg_speed() + self._avg_request_time() + + # punish errors and timeouts harder than slower speeds and more recent access + if self.error_response_count: + weight /= self.error_response_count + 1 # make sure we at least always divide by 2 + + if self.timeout_count: + weight /= self.timeout_count + 1 + return weight + + def __lt__(self, other): + return self.weight() < other.weight() + + def __repr__(self): + # return f"<{self.__class__.__name__} at {hex(id(self))}> w:{self.weight():.2f} r:{self.error_response_count} t:{self.timeout_count}" + return f"{self.id} {self._avg_speed():.2f} {self._avg_request_time():.2f} w:{self.weight():.2f} r:{self.error_response_count} t:{self.timeout_count}" diff --git a/neo/Network/p2pservice.py b/neo/Network/p2pservice.py new file mode 100644 index 000000000..f0bd0fea5 --- /dev/null +++ b/neo/Network/p2pservice.py @@ -0,0 +1,46 @@ +import asyncio +import logging + +from neo.Network.common.singleton import Singleton +from neo.Network.ledger import Ledger +from neo.Network.message import Message +from neo.Network.nodemanager import NodeManager +from neo.Network.syncmanager import SyncManager +from neo.Settings import settings + +from contextlib import suppress + + +class NetworkService(Singleton): + def init(self): + self.loop = asyncio.get_event_loop() + self.syncmgr = None + self.nodemgr = None + + self.nodemgr_task = None + + async def start(self): + Message._magic = settings.MAGIC + self.nodemgr = NodeManager() + self.syncmgr = SyncManager(self.nodemgr) + ledger = Ledger() + self.syncmgr.ledger = ledger + + logging.getLogger("asyncio").setLevel(logging.DEBUG) + self.loop.set_debug(False) + self.nodemgr_task = self.loop.create_task(self.nodemgr.start()) + self.loop.create_task(self.syncmgr.start()) + + async def shutdown(self): + if self.nodemgr_task and self.nodemgr_task.done(): + # starting nodemanager can fail if a port is in use, we need to retrieve and mute this exception on shutdown + with suppress(SystemExit): + self.nodemgr_task.exception() + + with suppress(asyncio.CancelledError): + if self.syncmgr: + await self.syncmgr.shutdown() + + with suppress(asyncio.CancelledError): + if self.nodemgr: + await self.nodemgr.shutdown() diff --git a/neo/Network/payloads/__init__.py b/neo/Network/payloads/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/neo/Network/payloads/addr.py b/neo/Network/payloads/addr.py new file mode 100644 index 000000000..26ed4a990 --- /dev/null +++ b/neo/Network/payloads/addr.py @@ -0,0 +1,40 @@ +from neo.Network.payloads.base import BasePayload +from neo.Network.core.io.binary_writer import BinaryWriter +from neo.Network.core.io.binary_reader import BinaryReader +from typing import List, Union +from neo.Network.payloads.networkaddress import NetworkAddressWithTime + + +class AddrPayload(BasePayload): + def __init__(self, addresses: List[NetworkAddressWithTime] = None): + self.addresses = addresses if addresses else [] + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. """ + writer.write_var_int(len(self.addresses)) + for address in self.addresses: + address.serialize(writer) + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize object. 
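+
+        The payload is a var_int count followed by that many serialized
+        NetworkAddressWithTime entries.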
""" + addr_list_len = reader.read_var_int() + for i in range(0, addr_list_len): + nawt = NetworkAddressWithTime() + nawt.deserialize(reader) + self.addresses.append(nawt) + + @classmethod + def deserialize_from_bytes(cls, data_stream: Union[bytes, bytearray]): + """ Deserialize object from a byte array. """ + br = BinaryReader(stream=data_stream) + addr_payload = cls() + addr_payload.deserialize(br) + br.cleanup() + return addr_payload + + def to_array(self) -> bytearray: + writer = BinaryWriter(stream=bytearray()) + self.serialize(writer) + data = bytearray(writer._stream.getvalue()) + writer.cleanup() + return data diff --git a/neo/Network/payloads/base.py b/neo/Network/payloads/base.py new file mode 100644 index 000000000..3af930a13 --- /dev/null +++ b/neo/Network/payloads/base.py @@ -0,0 +1,13 @@ +from neo.Network.core.mixin.serializable import SerializableMixin + + +class BasePayload(SerializableMixin): + + def serialize(self, writer) -> None: + pass + + def deserialize(self, reader) -> None: + pass + + def to_array(self) -> bytearray: + pass diff --git a/neo/Network/payloads/block.py b/neo/Network/payloads/block.py new file mode 100644 index 000000000..e4ac7ead5 --- /dev/null +++ b/neo/Network/payloads/block.py @@ -0,0 +1,72 @@ +from neo.Network.core.blockbase import BlockBase +from neo.Network.core.header import Header +from neo.Network.core.io.binary_reader import BinaryReader +from neo.Network.core.io.binary_writer import BinaryWriter +from neo.Network.core.uint256 import UInt256 +from typing import Union + + +class Block(BlockBase): + def __init__(self, prev_hash, timestamp, index, consensus_data, next_consensus, witness): + version = 0 + temp_merkleroot = UInt256.zero() + super(Block, self).__init__(version, prev_hash, temp_merkleroot, timestamp, index, consensus_data, next_consensus, witness) + self.prev_hash = prev_hash + self.timestamp = timestamp + self.index = index + self.consensus_data = consensus_data + self.next_consensus = next_consensus + self.witness = witness + self.transactions = [] # hardcoded to empty as we will not deserialize these + + # not part of the official Block implementation, just useful info for internal usage + self._tx_count = 0 + self._size = 0 + + def header(self) -> Header: + return Header(self.prev_hash, self.merkle_root, self.timestamp, self.index, self.consensus_data, + self.next_consensus, self.witness) + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. """ + super(Block, self).serialize(writer) + + len_transactions = len(self.transactions) + if len_transactions == 0: + writer.write_uint8(0) + else: + writer.write_var_int(len_transactions) + for tx in self.transactions: + tx.serialize(writer) + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize object. """ + super(Block, self).deserialize(reader) + + # ignore reading actual transactions, but we can determine the count + self._tx_count = reader.read_var_int() + + @classmethod + def deserialize_from_bytes(cls, data_stream: Union[bytes, bytearray]) -> 'Block': + """ Deserialize object from a byte array. """ + br = BinaryReader(stream=data_stream) + block = cls(None, None, None, None, None, None) + try: + block.deserialize(br) + # at this point we do not fully support all classes that can build up a block (e.g. 
Transactions) + # the normal size calculation would request each class for its size and sum them up + # we can shortcut this calculation in the absence of those classes by just determining the amount of bytes + # in the payload + block._size = len(data_stream) + except ValueError: + return None + finally: + br.cleanup() + return block + + def to_array(self) -> bytearray: + writer = BinaryWriter(stream=bytearray()) + self.serialize(writer) + data = bytearray(writer._stream.getvalue()) + writer.cleanup() + return data diff --git a/neo/Network/payloads/getblocks.py b/neo/Network/payloads/getblocks.py new file mode 100644 index 000000000..a51372379 --- /dev/null +++ b/neo/Network/payloads/getblocks.py @@ -0,0 +1,42 @@ +from neo.Network.payloads.base import BasePayload +from neo.Network.core.uint256 import UInt256 +from typing import TYPE_CHECKING, Union +from neo.Network.core.io.binary_writer import BinaryWriter + +if TYPE_CHECKING: + from neo.Network.core import BinaryReader + + +class GetBlocksPayload(BasePayload): + def __init__(self, start: UInt256, stop: UInt256 = None): + self.hash_start = [start] + self.hash_stop = stop if stop else UInt256.zero() + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. """ + length = len(self.hash_start) + writer.write_var_int(length) + for hash in self.hash_start: + writer.write_uint256(hash) + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize object. """ + length = reader.read_var_int() + self.hash_start = list(map(reader.read_uint256(), range(length))) + self.hash_stop = reader.read_uint256() + + @classmethod + def deserialize_from_bytes(cls, data_stream: Union[bytes, bytearray]): + """ Deserialize object from a byte array. """ + br = BinaryReader(stream=data_stream) + block_payload = cls() + block_payload.deserialize(br) + br.cleanup() + return block_payload + + def to_array(self) -> bytearray: + writer = BinaryWriter(stream=bytearray()) + self.serialize(writer) + data = bytearray(writer._stream.getvalue()) + writer.cleanup() + return data diff --git a/neo/Network/payloads/headers.py b/neo/Network/payloads/headers.py new file mode 100644 index 000000000..b8d2a3bc5 --- /dev/null +++ b/neo/Network/payloads/headers.py @@ -0,0 +1,51 @@ +from neo.Network.payloads.base import BasePayload +from typing import TYPE_CHECKING, Optional, Union, List +from neo.Network.core.header import Header +from neo.Network.core.io.binary_writer import BinaryWriter +from neo.Network.core.io.binary_reader import BinaryReader + +if TYPE_CHECKING: + from neo.Network.core import BinaryReader + + +class HeadersPayload(BasePayload): + def __init__(self, headers: Optional[List[Header]] = None): + self.headers = headers if headers else [] + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. """ + len_headers = len(self.headers) + if len_headers == 0: + writer.write_uint8(0) + else: + writer.write_var_int(len_headers) + for header in self.headers: + header.serialize(writer) + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize object + + Raises: + DeserializationError: if deserialization fails + """ + arr_length = reader.read_var_int() + for i in range(arr_length): + h = Header(None, None, None, None, None, None, None) + h.deserialize(reader) + self.headers.append(h) + + @classmethod + def deserialize_from_bytes(cls, data_stream: Union[bytes, bytearray]) -> 'HeadersPayload': + """ Deserialize object from a byte array. 
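+
+        Example (sketch; `raw` is assumed to hold the payload bytes of a
+        received 'headers' message):
+
+            payload = HeadersPayload.deserialize_from_bytes(raw)
+            print(len(payload.headers))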
""" + br = BinaryReader(stream=data_stream) + headers_payload = cls() + headers_payload.deserialize(br) + br.cleanup() + return headers_payload + + def to_array(self) -> bytearray: + writer = BinaryWriter(stream=bytearray()) + self.serialize(writer) + data = bytearray(writer._stream.getvalue()) + writer.cleanup() + return data diff --git a/neo/Network/payloads/inventory.py b/neo/Network/payloads/inventory.py new file mode 100644 index 000000000..e97156401 --- /dev/null +++ b/neo/Network/payloads/inventory.py @@ -0,0 +1,58 @@ +from neo.Network.payloads.base import BasePayload +from enum import Enum +from typing import Union, List +from neo.Network.core.io.binary_writer import BinaryWriter +from neo.Network.core.io.binary_reader import BinaryReader +from neo.Network.core.uint256 import UInt256 + + +class InventoryType(Enum): + tx = 0x01 + block = 0x02 + consensus = 0xe0 + + +class InventoryPayload(BasePayload): + + def __init__(self, type: InventoryType = None, hashes: List[UInt256] = None): + self.type = type + self.hashes = hashes if hashes else [] + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. """ + writer.write_uint8(self.type.value) + writer.write_var_int(len(self.hashes)) + for h in self.hashes: # type: UInt256 + writer.write_bytes(h.to_array()) + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize object. """ + self.type = InventoryType(reader.read_uint8()) + self.hashes = [] + hash_list_count = reader.read_var_int() + + try: + for i in range(0, hash_list_count): + self.hashes.append(UInt256(data=reader.read_bytes(32))) + except ValueError: + raise ValueError("Invalid hashes data") + + @classmethod + def deserialize_from_bytes(cls, data_stream: Union[bytes, bytearray]): + """ Deserialize object from a byte array. """ + br = BinaryReader(stream=data_stream) + inv_payload = cls() + try: + inv_payload.deserialize(br) + except ValueError: + return None + finally: + br.cleanup() + return inv_payload + + def to_array(self) -> bytearray: + writer = BinaryWriter(stream=bytearray()) + self.serialize(writer) + data = bytearray(writer._stream.getvalue()) + writer.cleanup() + return data diff --git a/neo/Network/payloads/networkaddress.py b/neo/Network/payloads/networkaddress.py new file mode 100644 index 000000000..dce0918b3 --- /dev/null +++ b/neo/Network/payloads/networkaddress.py @@ -0,0 +1,66 @@ +from typing import TYPE_CHECKING +from datetime import datetime +from neo.Network.core.size import Size as s +from neo.Network.payloads.base import BasePayload + +if TYPE_CHECKING: + from neo.Network.core import BinaryReader + from neo.Network.core.io.binary_writer import BinaryWriter + + +class NetworkAddressWithTime(BasePayload): + NODE_NETWORK = 1 + + def __init__(self, address: str = None, port: int = None, services: int = 0, timestamp: int = None) -> None: + """ Create an instance. """ + if timestamp is None: + self.timestamp = int(datetime.utcnow().timestamp()) + else: + self.timestamp = timestamp + + self.address = address + self.port = port + self.services = services + + @property + def size(self) -> int: + """ Get the total size in bytes of the object. """ + return s.uint32 + s.uint64 + 16 + s.uint16 + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. 
""" + writer.write_uint32(self.timestamp) + writer.write_uint64(self.services) + # turn ip address into bytes + octets = bytearray(map(lambda oct: int(oct), self.address.split('.'))) + # pad to fixed length 16 + octets += bytearray(12) + # and finally write to stream + writer.write_bytes(octets) + + writer.write_uint16(self.port, endian='>') + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize object. """ + self.timestamp = reader.read_uint32() + self.services = reader.read_uint64() + full_address_bytes = bytearray(reader.read_fixed_string(16)) + ip_bytes = full_address_bytes[-4:] + self.address = '.'.join(map(lambda b: str(b), ip_bytes)) + self.port = reader.read_uint16(endian='>') + + def to_array(self) -> bytearray: + writer = BinaryWriter(stream=bytearray()) + self.serialize(writer) + data = bytearray(writer._stream.getvalue()) + writer.cleanup() + return data + + def __str__(self) -> str: + """ + Get the string representation of the network address. + + Returns: + str: address:port + """ + return f"{self.address}:{self.port}" diff --git a/neo/Network/payloads/ping.py b/neo/Network/payloads/ping.py new file mode 100644 index 000000000..bafb4d73b --- /dev/null +++ b/neo/Network/payloads/ping.py @@ -0,0 +1,54 @@ +from typing import Union +from neo.Network.core.size import Size as s +from neo.Network.payloads.base import BasePayload +from neo.Network.core.io.binary_writer import BinaryWriter +from neo.Network.core.io.binary_reader import BinaryReader +from datetime import datetime +from random import randint + + +class PingPayload(BasePayload): + def __init__(self, height: int = 0) -> None: + self.current_height = height + self.timestamp = int(datetime.utcnow().timestamp()) + self.nonce = randint(100, 10000) + + def __len__(self): + return self.size() + + def size(self) -> int: + """ + Get the total size in bytes of the object. + + Returns: + int: size. + """ + return s.uint32 + s.uint32 + s.uint32 + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. """ + writer.write_uint32(self.current_height) + writer.write_uint32(self.timestamp) + writer.write_uint32(self.nonce) + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize object. """ + self.current_height = reader.read_uint32() + self.timestamp = reader.read_uint32() + self.nonce = reader.read_uint32() + + @classmethod + def deserialize_from_bytes(cls, data_stream: Union[bytes, bytearray]) -> 'PingPayload': + """ Deserialize object from a byte array. 
""" + br = BinaryReader(stream=data_stream) + ping_payload = cls() + ping_payload.deserialize(br) + br.cleanup() + return ping_payload + + def to_array(self) -> bytearray: + writer = BinaryWriter(stream=bytearray()) + self.serialize(writer) + data = bytearray(writer._stream.getvalue()) + writer.cleanup() + return data diff --git a/neo/Network/payloads/version.py b/neo/Network/payloads/version.py new file mode 100644 index 000000000..f6e917c41 --- /dev/null +++ b/neo/Network/payloads/version.py @@ -0,0 +1,81 @@ +import datetime +import random +from typing import Union +from neo.Network.core.size import Size as s +from neo.Network.core.size import GetVarSize +from neo.Network.payloads.base import BasePayload +from neo.Network.payloads.networkaddress import NetworkAddressWithTime +from neo.Network.core.io.binary_writer import BinaryWriter +from neo.Network.core.io.binary_reader import BinaryReader + + +class VersionPayload(BasePayload): + + def __init__(self, port: int = None, nonce: int = None, userAgent: str = None) -> None: + """ + Create an instance. + + Args: + port: + nonce: + userAgent: client user agent string. + """ + # if port and nonce and userAgent: + self.port = port + self.version = 0 + self.services = NetworkAddressWithTime.NODE_NETWORK + self.timestamp = int(datetime.datetime.utcnow().timestamp()) + self.nonce = nonce if nonce else random.randint(0, 10000) + self.user_agent = userAgent if userAgent else "" + self.start_height = 0 # TODO: update once blockchain class is available + self.relay = True + + def __len__(self): + return self.size() + + def size(self) -> int: + """ + Get the total size in bytes of the object. + + Returns: + int: size. + """ + return s.uint32 + s.uint64 + s.uint32 + s.uint16 + s.uint32 + GetVarSize(self.user_agent) + s.uint32 + s.uint8 + + def serialize(self, writer: 'BinaryWriter') -> None: + """ Serialize object. """ + writer.write_uint32(self.version) + writer.write_uint64(self.services) + writer.write_uint32(self.timestamp) + writer.write_uint16(self.port) + writer.write_uint32(self.nonce) + writer.write_var_string(self.user_agent) + writer.write_uint32(self.start_height) + writer.write_bool(self.relay) + + def deserialize(self, reader: 'BinaryReader') -> None: + """ Deserialize object. """ + self.version = reader.read_uint32() + self.services = reader.read_uint64() + self.timestamp = reader.read_uint32() + self.port = reader.read_uint16() + self.nonce = reader.read_uint32() + self.user_agent = reader.read_var_string() + self.start_height = reader.read_uint32() + self.relay = reader.read_bool() + + @classmethod + def deserialize_from_bytes(cls, data_stream: Union[bytes, bytearray]): + """ Deserialize object from a byte array. 
""" + br = BinaryReader(stream=data_stream) + version_payload = cls() + version_payload.deserialize(br) + br.cleanup() + return version_payload + + def to_array(self) -> bytearray: + writer = BinaryWriter(stream=bytearray()) + self.serialize(writer) + data = bytearray(writer._stream.getvalue()) + writer.cleanup() + return data diff --git a/neo/Network/protocol.py b/neo/Network/protocol.py new file mode 100644 index 000000000..f3ed01a0e --- /dev/null +++ b/neo/Network/protocol.py @@ -0,0 +1,101 @@ +import asyncio +import struct +from typing import Optional +from neo.Network.node import NeoNode +from neo.Network.message import Message +from asyncio.streams import StreamReader, StreamReaderProtocol, StreamWriter +from asyncio import events +from neo.logging import log_manager + +logger = log_manager.getLogger('network') + + +class NeoProtocol(StreamReaderProtocol): + def __init__(self, *args, quality_check=False, **kwargs): + """ + + Args: + *args: + quality_check (bool): there are times when we only establish a connection to check the quality of the node/address + **kwargs: + """ + self._stream_reader = StreamReader() + self._stream_writer = None + nodemanager = kwargs.pop('nodemanager') + self.client = NeoNode(self, nodemanager, quality_check) + self._loop = events.get_event_loop() + super().__init__(self._stream_reader) + + def connection_made(self, transport: asyncio.transports.BaseTransport) -> None: + super().connection_made(transport) + self._stream_writer = StreamWriter(transport, self, self._stream_reader, self._loop) + + if self.client: + asyncio.create_task(self.client.connection_made(transport)) + + def connection_lost(self, exc: Optional[Exception] = None) -> None: + if self.client: + task = asyncio.create_task(self.client.connection_lost(exc)) + task.add_done_callback(lambda args: super(NeoProtocol, self).connection_lost(exc)) + else: + super().connection_lost(exc) + + def eof_received(self) -> bool: + self._stream_reader.feed_eof() + + self.connection_lost() + return True + # False == Do not keep connection open, this makes sure that `connection_lost` gets called. + # return False + + async def send_message(self, message: Message) -> None: + try: + self._stream_writer.write(message.to_array()) + await self._stream_writer.drain() + except ConnectionResetError: + # print("connection reset") + self.connection_lost(ConnectionResetError) + except ConnectionError: + # print("connection error") + self.connection_lost(ConnectionError) + except asyncio.CancelledError: + # print("task cancelled, closing connection") + self.connection_lost(asyncio.CancelledError) + except Exception as e: + # print(f"***** woah what happened here?! {traceback.format_exc()}") + self.connection_lost() + + async def read_message(self, timeout: int = 30) -> Message: + if timeout == 0: + # avoid memleak. 
See: https://bugs.python.org/issue37042 + timeout = None + + async def _read(): + try: + message_header = await self._stream_reader.readexactly(24) + magic, command, payload_length, checksum = struct.unpack('I 12s I I', + message_header) # uint32, 12byte-string, uint32, uint32 + + payload_data = await self._stream_reader.readexactly(payload_length) + payload, = struct.unpack('{}s'.format(payload_length), payload_data) + + except Exception: + # ensures we break out of the main run() loop of Node, which triggers a disconnect callback to clean up + self.client.disconnecting = True + return None + + m = Message(magic, command.rstrip(b'\x00').decode('utf-8'), payload) + + if checksum != m.get_checksum(payload): + logger.debug("Message checksum incorrect") + return None + else: + return m + try: + return await asyncio.wait_for(_read(), timeout) + except Exception: + return None + + def disconnect(self) -> None: + if self._stream_writer: + self._stream_writer.close() diff --git a/neo/Network/relaycache.py b/neo/Network/relaycache.py new file mode 100644 index 000000000..b41d74f64 --- /dev/null +++ b/neo/Network/relaycache.py @@ -0,0 +1,38 @@ +from contextlib import suppress + +from neo.Core.Block import Block as OrigBlock +from neo.Network.common import msgrouter +from neo.Network.common.singleton import Singleton +from neo.logging import log_manager + +logger = log_manager.getLogger('network') + + +# TODO: how can we tell if our item is rejected by consensus nodes other than not being processed after x time? cache can grow infinite in size + +class RelayCache(Singleton): + def init(self): + self.cache = dict() # uint256 : tx/block/consensus data + msgrouter.on_block_persisted += self.update_cache_for_block_persist + + def add(self, old_style_inventory) -> None: + # TODO: make this UInt256 instead of the string identifier once we've fully moved to the new implementation + self.cache.update({old_style_inventory.Hash.ToString(): old_style_inventory}) + + def get_and_remove(self, new_style_hash): + try: + return self.cache.pop(new_style_hash.to_string()) + except KeyError: + return None + + def try_get(self, new_style_hash): + return self.cache.get(new_style_hash.to_string(), None) + + def update_cache_for_block_persist(self, orig_block: OrigBlock) -> None: + for tx in orig_block.Transactions: + with suppress(KeyError): + self.cache.pop(tx.Hash.ToString()) + logger.debug(f"Found {tx.Hash} in last persisted block. 
Removing from relay cache") + + def reset(self) -> None: + self.cache = dict() diff --git a/neo/Network/requestinfo.py b/neo/Network/requestinfo.py new file mode 100644 index 000000000..50b720b72 --- /dev/null +++ b/neo/Network/requestinfo.py @@ -0,0 +1,21 @@ +from neo.Network.flightinfo import FlightInfo + + +class RequestInfo: + def __init__(self, height): + self.height: int = height + self.failed_nodes: dict = dict() # nodeId: timeout time + self.failed_total: int = 0 + self.flights: dict = dict() # nodeId:FlightInfo + self.last_used_node = None + + def add_new_flight(self, flight_info: FlightInfo) -> None: + self.flights[flight_info.node_id] = flight_info + self.last_used_node = flight_info.node_id + + def most_recent_flight(self) -> FlightInfo: + return self.flights[self.last_used_node] + + def mark_failed_node(self, node_id) -> None: + self.failed_nodes[node_id] = self.failed_nodes.get(node_id, 0) + 1 + self.failed_total += 1 diff --git a/neo/Network/syncmanager.py b/neo/Network/syncmanager.py new file mode 100644 index 000000000..ffff96a17 --- /dev/null +++ b/neo/Network/syncmanager.py @@ -0,0 +1,442 @@ +import asyncio +import traceback +from datetime import datetime +from neo.Network.core.header import Header +from typing import TYPE_CHECKING, List +from neo.Network.flightinfo import FlightInfo +from neo.Network.requestinfo import RequestInfo +from neo.Network.payloads.inventory import InventoryType +from neo.Network.common import msgrouter +from neo.Network.common.singleton import Singleton +from contextlib import suppress +from neo.Network.core.uint256 import UInt256 + +from neo.logging import log_manager + +logger = log_manager.getLogger('syncmanager') +# log_manager.config_stdio([('syncmanager', 10)]) + +if TYPE_CHECKING: + from neo.Network.nodemanager import NodeManager + from neo.Network.payloads import Block + + +class SyncManager(Singleton): + HEADER_MAX_LOOK_AHEAD = 6000 + HEADER_REQUEST_TIMEOUT = 5 + + BLOCK_MAX_CACHE_SIZE = 500 + BLOCK_NETWORK_REQ_LIMIT = 500 + BLOCK_REQUEST_TIMEOUT = 5 + + def init(self, nodemgr: 'NodeManager'): + self.nodemgr = nodemgr + self.controller = None + self.block_requests = dict() # header_hash:RequestInfo + self.header_request = None # type: RequestInfo + self.ledger = None + self.block_cache = [] + self.header_cache = [] + self.raw_block_cache = [] + self.is_persisting_blocks = False + self.is_persisting_headers = False + self.keep_running = True + self.service_task = None + self.persist_task = None + self.health_task = None + + msgrouter.on_headers += self.on_headers_received + msgrouter.on_block += self.on_block_received + + async def start(self) -> None: + while not self.nodemgr.running: + await asyncio.sleep(0.1) + self.service_task = asyncio.create_task(self.run_service()) + self.health_task = asyncio.create_task(self.block_health()) + + async def shutdown(self): + print("Shutting down sync manager...", end='') + self.keep_running = False + self.block_cache = [] + self.health_task.cancel() + shutdown_tasks = [self.service_task, self.health_task] + if self.persist_task: + shutdown_tasks.append(self.persist_task) + await asyncio.gather(*shutdown_tasks, return_exceptions=True) + + print("DONE") + + async def block_health(self): + # TODO: move this to nodemanager, once the network in general supports ping/pong + # we can then make smarter choices by looking at individual nodes advancing or not and dropping just those + error_counter = 0 + last_height = await self.ledger.cur_block_height() + while self.keep_running: + await 
asyncio.sleep(15) + cur_height = await self.ledger.cur_block_height() + if cur_height == last_height: + error_counter += 1 + if error_counter == 3: + to_disconnect = list(map(lambda n: n, self.nodemgr.nodes)) + logger.debug(f"Block height not advancing. Replacing nodes: {to_disconnect}") + for n in to_disconnect: + await self.nodemgr.replace_node(n) + else: + error_counter = 0 + + last_height = cur_height + + async def run_service(self): + while self.keep_running: + await self.check_timeout() + await self.sync() + await asyncio.sleep(1) + + async def sync(self) -> None: + await self.sync_header() + await self.sync_block() + await self.persist_headers() + if not self.is_persisting_blocks: + self.persist_task = asyncio.create_task(self.persist_blocks()) + + async def sync_header(self) -> None: + if self.header_request: + return + + cur_header_height = await self.ledger.cur_header_height() + cur_block_height = await self.ledger.cur_block_height() + if cur_header_height - cur_block_height >= self.HEADER_MAX_LOOK_AHEAD: + return + + node = self.nodemgr.get_next_node(cur_header_height + 1) + if not node: + # No connected nodes or no nodes with our height. We'll wait for node manager to resolve this + # or for the nodes to increase their height on the next produced block + return + + self.header_request = RequestInfo(cur_header_height + 1) + self.header_request.add_new_flight(FlightInfo(node.nodeid, cur_header_height + 1)) + + cur_header_hash = await self.ledger.header_hash_by_height(cur_header_height) + await node.get_headers(hash_start=cur_header_hash) + + logger.debug(f"Requested headers starting at {cur_header_height + 1} from node {node.nodeid_human}") + node.nodeweight.append_new_request_time() + + async def persist_headers(self): + self.is_persisting_headers = True + if len(self.header_cache) > 0: + while self.keep_running: + try: + headers = self.header_cache.pop(0) + try: + await self.ledger.add_headers(headers) + except Exception as e: + print(traceback.format_exc()) + await asyncio.sleep(0) + except IndexError: + # cache empty + break + + # reset header_request such that the a new header sync task can be added + self.header_request = None + logger.debug("Finished processing headers") + + self.is_persisting_headers = False + + async def sync_block(self) -> None: + # to simplify syncing, don't ask for more data if we still have requests in flight + if len(self.block_requests) > 0: + return + + # the block cache might not have been fully processed, so we want to avoid asking for data we actually already have + best_block_height = await self.get_best_stored_block_height() + cur_header_height = await self.ledger.cur_header_height() + blocks_to_fetch = cur_header_height - best_block_height + if blocks_to_fetch <= 0: + return + + block_cache_space = self.BLOCK_MAX_CACHE_SIZE - len(self.block_cache) + if block_cache_space <= 0: + return + + if blocks_to_fetch > block_cache_space or blocks_to_fetch > self.BLOCK_NETWORK_REQ_LIMIT: + blocks_to_fetch = min(block_cache_space, self.BLOCK_NETWORK_REQ_LIMIT) + + try: + best_node_height = max(map(lambda node: node.best_height, self.nodemgr.nodes)) + except ValueError: + # if the node list is empty max() fails on an empty list + return + + node = self.nodemgr.get_next_node(best_node_height) + if not node: + # no nodes with our desired height. 
We'll wait for node manager to resolve this + # or for the nodes to increase their height on the next produced block + return + + hashes = [] + endheight = None + for i in range(1, blocks_to_fetch + 1): + next_block_height = best_block_height + i + if self.is_in_blockcache(next_block_height): + continue + + if next_block_height > best_node_height: + break + + next_header_hash = await self.ledger.header_hash_by_height(next_block_height) + if next_header_hash == UInt256.zero(): + # we do not have enough headers to fill the block cache. That's fine, just return + break + + endheight = next_block_height + hashes.append(next_header_hash) + self.add_block_flight_info(node.nodeid, next_block_height, next_header_hash) + + if len(hashes) > 0: + logger.debug(f"Asking for blocks {best_block_height + 1} - {endheight} from {node.nodeid_human}") + await node.get_data(InventoryType.block, hashes) + node.nodeweight.append_new_request_time() + + async def persist_blocks(self) -> None: + self.is_persisting_blocks = True + while self.keep_running: + try: + b = self.block_cache.pop(0) + raw_b = self.raw_block_cache.pop(0) + await self.ledger.add_block(raw_b) + await asyncio.sleep(0.001) + except IndexError: + # cache empty + break + self.is_persisting_blocks = False + + async def check_timeout(self) -> None: + task1 = asyncio.create_task(self.check_header_timeout()) + task2 = asyncio.create_task(self.check_block_timeout()) + try: + await asyncio.gather(task1, task2) + except Exception: + logger.debug(traceback.format_exc()) + + async def check_header_timeout(self) -> None: + if not self.header_request: + # no data requests outstanding + return + + last_flight_info = self.header_request.most_recent_flight() + + now = datetime.utcnow().timestamp() + delta = now - last_flight_info.start_time + if delta < self.HEADER_REQUEST_TIMEOUT: + # we're still good on time + return + + node = self.nodemgr.get_node_by_nodeid(last_flight_info.node_id) + if node: + logger.debug(f"Header timeout limit exceeded by {delta - self.HEADER_REQUEST_TIMEOUT:.2f}s for node {node.nodeid_human}") + + cur_header_height = await self.ledger.cur_header_height() + if last_flight_info.height <= cur_header_height: + # it has already come in in the meantime + # reset so sync_header will request new headers + self.header_request = None + return + + # punish node that is causing header_timeout and retry using another node + self.header_request.mark_failed_node(last_flight_info.node_id) + await self.nodemgr.add_node_timeout_count(last_flight_info.node_id) + + # retry with a new node + node = self.nodemgr.get_node_with_min_failed_time(self.header_request) + if node is None: + # only happens if there are no nodes that have data matching our needed height + self.header_request = None + return + + hash = await self.ledger.header_hash_by_height(last_flight_info.height - 1) + logger.debug(f"Retry requesting headers starting at {last_flight_info.height} from new node {node.nodeid_human}") + await node.get_headers(hash_start=hash) + + # restart start_time of flight info or else we'll time out too fast for the next node + self.header_request.add_new_flight(FlightInfo(node.nodeid, last_flight_info.height)) + node.nodeweight.append_new_request_time() + + async def check_block_timeout(self) -> None: + if len(self.block_requests) == 0: + # no data requests outstanding + return + + now = datetime.utcnow().timestamp() + block_timeout_flights = dict() + + # test for timeout + for block_hash, request_info in self.block_requests.items(): # type: _, RequestInfo + 
flight_info = request_info.most_recent_flight() + if now - flight_info.start_time > self.BLOCK_REQUEST_TIMEOUT: + block_timeout_flights[block_hash] = flight_info + + if len(block_timeout_flights) == 0: + # no timeouts + return + + # 1) we first filter out invalid requests as some might have come in by now + # 2) for each block_sync cycle we requested blocks in batches of max 500 per node, now when resending we try to + # create another batch + # 3) Blocks arrive one by one in 'inv' messages. In the block_sync cycle we created a FlightInfo object per + # requested block such that we can determine speed among others. If one block in a request times out all + # others for the same request will of course do as well (as they arrive in a linear fashion from the same node). + # As such we only want to tag the individual node once (per request) for being slower than our timeout threshold, not 500 times. + remaining_requests = [] + nodes_to_tag_for_timeout = set() + nodes_to_mark_failed = dict() + + best_stored_block_height = await self.get_best_stored_block_height() + + for block_hash, fi in block_timeout_flights.items(): # type: _, FlightInfo + nodes_to_tag_for_timeout.add(fi.node_id) + + try: + request_info = self.block_requests[block_hash] + except KeyError: + # means on_block_received popped it off the list + # we don't have to retry for data anymore + continue + + if fi.height <= best_stored_block_height: + with suppress(KeyError): + self.block_requests.pop(block_hash) + continue + + nodes_to_mark_failed[request_info] = fi.node_id + remaining_requests.append((block_hash, fi.height, request_info)) + + for nodeid in nodes_to_tag_for_timeout: + await self.nodemgr.add_node_timeout_count(nodeid) + + for request_info, node_id in nodes_to_mark_failed.items(): + request_info.mark_failed_node(node_id) + + # for the remaining requests that need to be queued again, we create new FlightInfo objects that use a new node + # and ask them in a single batch from that new node. + hashes = [] + if len(remaining_requests) > 0: + # retry the batch with a new node + ri_first = remaining_requests[0][2] + ri_last = remaining_requests[-1][2] + + # using `ri_last` because this has the highest block height and we want a node that supports that + node = self.nodemgr.get_node_with_min_failed_time(ri_last) + if not node: + return + + for block_hash, height, ri in remaining_requests: # type: _, int, RequestInfo + ri.add_new_flight(FlightInfo(node.nodeid, height)) + + hashes.append(block_hash) + + if len(hashes) > 0: + logger.debug(f"Block timeout for blocks {ri_first.height} - {ri_last.height}. 
Trying again using new node {node.nodeid_human} {hashes[0]}") + await node.get_data(InventoryType.block, hashes) + node.nodeweight.append_new_request_time() + + async def on_headers_received(self, from_nodeid, headers: List[Header]) -> int: + if len(headers) == 0: + return -1 + + if self.header_request is None: + return -2 + + height = headers[0].index + if height != self.header_request.height: + # received headers we did not ask for + return -3 + logger.debug(f"Headers received {headers[0].index} - {headers[-1].index}") + + if headers in self.header_cache: + return -4 + + cur_header_height = await self.ledger.cur_header_height() + if height <= cur_header_height: + return -5 + + self.header_cache.append(headers) + + return 1 + + async def on_block_received(self, from_nodeid, block: 'Block', raw_block) -> None: + # TODO: take out raw_block and raw_block_cache once we can serialize a full block + # print(f"{block.index} {block.hash} received") + + next_header_height = await self.ledger.cur_header_height() + 1 + if block.index > next_header_height: + return + + cur_block_height = await self.ledger.cur_block_height() + if block.index <= cur_block_height: + return + + try: + ri = self.block_requests.pop(block.hash) # type: RequestInfo + fi = ri.flights.pop(from_nodeid) # type: FlightInfo + now = datetime.utcnow().timestamp() + delta_time = now - fi.start_time + speed = (block._size / 1024) / delta_time # KB/s + + node = self.nodemgr.get_node_by_nodeid(fi.node_id) + if node: + node.nodeweight.append_new_speed(speed) + except KeyError: + # it's a block we did not ask for + # this can either be caused by rogue actors sending bad blocks + # or be a reply to our `get_data` on a broadcasted `inv` message by the node. + # (neo-cli nodes broadcast `inv` messages with their latest hash, we currently need to do a `get_data` + # and receive the full block to know what their best height is as we have no other mechanism (yet)) + # TODO: remove once all nodes on the network run neo-cli 2.10.1 or above, which supports ping/pong for height + sync_distance = block.index - cur_block_height + if sync_distance != 1: + return + # but if the distance is 1 we're in sync so we add the block anyway + # to avoid having the `sync_block` task request the same data again + # this is also necessary for neo-cli nodes because they maintain a TaskSession and refuse to send recently requested data + + if not self.is_in_blockcache(block.index) and self.keep_running: + self.block_cache.append(block) + self.raw_block_cache.append(raw_block) + + async def get_best_stored_block_height(self) -> int: + """ + Helper to return the highest block in our possession (either in ledger or in block_cache) + """ + best_block_cache_height = 0 + if len(self.block_cache) > 0: + best_block_cache_height = self.block_cache[-1].index + + ledger_height = await self.ledger.cur_block_height() + + return max(ledger_height, best_block_cache_height) + + def is_in_blockcache(self, block_height: int) -> bool: + for b in self.block_cache: + if b.index == block_height: + return True + else: + return False + + def add_block_flight_info(self, nodeid, height, header_hash) -> None: + request_info = self.block_requests.get(header_hash, None) # type: RequestInfo + + if request_info is None: + # no outstanding requests for this particular hash, so we create it + req = RequestInfo(height) + req.add_new_flight(FlightInfo(nodeid, height)) + self.block_requests[header_hash] = req + else: + request_info.flights.update({nodeid: FlightInfo(nodeid, height)}) + + def reset(self) 
-> None: + self.header_request = None + self.block_requests = dict() + self.block_cache = [] + self.raw_block_cache = [] diff --git a/neo/Network/test_address.py b/neo/Network/test_address.py deleted file mode 100644 index 354f0479a..000000000 --- a/neo/Network/test_address.py +++ /dev/null @@ -1,88 +0,0 @@ -from neo.Utils.NeoTestCase import NeoTestCase -from neo.Network.address import Address -from datetime import datetime - - -class AddressTest(NeoTestCase): - def test_init_simple(self): - host = '127.0.0.1:80' - a = Address(host) - self.assertEqual(0, a.last_connection) - self.assertEqual(a.address, host) - - # test custom 'last_connection_to' - b = Address(host, 123) - self.assertEqual(123, b.last_connection) - - def test_now_helper(self): - n = Address.Now() - delta = datetime.now().utcnow().timestamp() - n - self.assertTrue(delta < 2) - - def test_equality(self): - """ - Only the host:port matters in equality - """ - a = Address('127.0.0.1:80', last_connection_to=0) - b = Address('127.0.0.1:80', last_connection_to=0) - c = Address('127.0.0.1:99', last_connection_to=0) - self.assertEqual(a, b) - self.assertNotEqual(a, c) - - # last connected does not influence equality - b.last_connection = 123 - self.assertEqual(a, b) - - # different port does change equality - b.address = "127.0.0.1:99" - self.assertNotEqual(a, b) - - # test diff types - self.assertNotEqual(int(1), a) - self.assertNotEqual("127.0.0.1:80", a) - - def test_repr_and_str(self): - host = '127.0.0.1:80' - a = Address(host, last_connection_to=0) - self.assertEqual(host, str(a)) - - x = repr(a) - self.assertIn("Address", x) - self.assertIn(host, x) - - def test_split(self): - a = Address('127.0.0.1:80') - host, port = a.split(':') - self.assertEqual(host, '127.0.0.1') - self.assertEqual(port, '80') - - host, port = a.rsplit(':', maxsplit=1) - self.assertEqual(host, '127.0.0.1') - self.assertEqual(port, '80') - - def test_str_formatting(self): - a = Address('127.0.0.1:80') - expected = " 127.0.0.1:80" - out = f"{a:>15}" - self.assertEqual(expected, out) - - def test_list_lookup(self): - a = Address('127.0.0.1:80') - b = Address('127.0.0.2:80') - c = Address('127.0.0.1:80') - d = Address('127.0.0.1:99') - - z = [a, b] - self.assertTrue(a in z) - self.assertTrue(b in z) - # duplicate check, equals to 'a' - self.assertTrue(c in z) - self.assertFalse(d in z) - - def test_dictionary_lookup(self): - """for __hash__""" - a = Address('127.0.0.1:80') - b = Address('127.0.0.2:80') - addr = {a: 1, b: 2} - self.assertEqual(addr[a], 1) - self.assertEqual(addr[b], 2) diff --git a/neo/Network/test_network.py b/neo/Network/test_network.py deleted file mode 100644 index 1c5c56505..000000000 --- a/neo/Network/test_network.py +++ /dev/null @@ -1,218 +0,0 @@ -""" -Test handling of a node(s) disconnecting. 
Reasons can be: -- we shutdown node due bad responsiveness -- we shutdown node because we shutdown -- node shuts us down for unknown reason -- node shuts us down because they shutdown -""" - -from twisted.trial import unittest as twisted_unittest -from twisted.internet.address import IPv4Address -from twisted.internet import error -from twisted.test import proto_helpers -from twisted.python import failure - -from neo.Network.NodeLeader import NodeLeader -from neo.Network.address import Address -from neo.Network.Utils import TestTransportEndpoint -from neo.Network.NeoNode import NeoNode, HEARTBEAT_BLOCKS -from neo.Utils.NeoTestCase import NeoTestCase - - -class NetworkConnectionLostTests(twisted_unittest.TestCase, NeoTestCase): - def setUp(self): - self.node = None - self.leader = NodeLeader.Instance() - - host, port = '127.0.0.1', 8080 - self.addr = Address(f"{host}:{port}") - - # we use a helper class such that we do not have to setup a real TCP connection - peerAddress = IPv4Address('TCP', host, port) - self.endpoint = TestTransportEndpoint(self.leader.reactor, str(self.addr), proto_helpers.StringTransportWithDisconnection(peerAddress=peerAddress)) - - # store our deferred so we can add callbacks - self.d = self.leader.SetupConnection(self.addr, self.endpoint) - # make sure we create a fully running client - self.d.addCallback(self.do_handshake) - - def tearDown(self): - def end(err): - self.leader.Reset() - - if self.node and self.node.connected: - d = self.node.Disconnect() - d.addBoth(end) - return d - else: - end(None) - - def do_handshake(self, node: NeoNode): - self.node = node - raw_version = b"\xb1\xdd\x00\x00version\x00\x00\x00\x00\x00'\x00\x00\x00a\xbb\x9av\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x0ef\x9e[mO3\xe7q\x08\x0b/NEO:2.7.4/=\x8b\x00\x00\x01" - raw_verack = b'\xb1\xdd\x00\x00verack\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00]\xf6\xe0\xe2' - node.dataReceived(raw_version + raw_verack) - return node - - def test_connection_lost_by_us(self): - """ - Test that _we_ can force disconnect nodes and cleanup properly - - Expected behaviour: - - added address to DEAD_ADDR list as it's unusable - - removed address from `KNOWN_ADDR` as it's unusable - - stopped all looping tasks of the node - - address not in connected peers list - """ - - def should_not_happen(_): - self.fail("Should not have been called, as our forced disconnection should call the `Errback` on the deferred") - - def conn_lost(_failure, expected_error): - self.assertEqual(type(_failure.value), expected_error) - self.assertIn(self.addr, self.leader.DEAD_ADDRS) - self.assertNotIn(self.addr, self.leader.KNOWN_ADDRS) - self.assertNotIn(self.addr, self.leader.Peers) - - node = self.endpoint.tr.protocol # type: NeoNode - self.assertFalse(node.has_tasks_running()) - - def conn_setup(node: NeoNode): - # at this point we should have a fully connected node, so lets try disconnecting from it - d1 = node.Disconnect() - d1.addCallback(should_not_happen) - d1.addErrback(conn_lost, error.ConnectionDone) - return d1 - - self.d.addCallback(conn_setup) - - return self.d - - def test_connection_lost_normally_by_them(self): - """ - Test handling of a normal connection lost by them (e.g. 
due to them shutting down) - - Expected behaviour: - - address not in DEAD_ADDR list as it is still useable - - address remains present in `KNOWN_ADDR` as it is still unusable - - stopped all looping tasks of the node - - address not in connected peers list - """ - - def conn_setup(node: NeoNode): - # at this point we should have a fully connected node, so lets try to simulate a connection lost by the other side - with self.assertLogHandler('network', 10) as log: - node.connectionLost(failure.Failure(error.ConnectionDone())) - - self.assertTrue("disconnected normally with reason" in log.output[-1]) - self.assertNotIn(self.addr, self.leader.DEAD_ADDRS) - self.assertIn(self.addr, self.leader.KNOWN_ADDRS) - self.assertNotIn(self.addr, self.leader.Peers) - - self.assertFalse(node.has_tasks_running()) - - self.d.addCallback(conn_setup) - - return self.d - - def test_connection_lost_abnormally_by_them(self): - """ - Test handling of a connection lost by them - - Expected behaviour: - - address not in DEAD_ADDR list as it might still be unusable - - address present in `KNOWN_ADDR` as it might still be unusable - - stopped all looping tasks of the node - - address not in connected peers list - """ - - def conn_setup(node: NeoNode): - # at this point we should have a fully connected node, so lets try to simulate a connection lost by the other side - with self.assertLogHandler('network', 10) as log: - node.connectionLost(failure.Failure(error.ConnectionLost())) - - self.assertIn("disconnected with connectionlost reason", log.output[-1]) - self.assertIn(str(error.ConnectionLost()), log.output[-1]) - self.assertIn("non-clean fashion", log.output[-1]) - - self.assertNotIn(self.addr, self.leader.DEAD_ADDRS) - self.assertIn(self.addr, self.leader.KNOWN_ADDRS) - self.assertNotIn(self.addr, self.leader.Peers) - - self.assertFalse(node.has_tasks_running()) - - self.d.addCallback(conn_setup) - - return self.d - - def test_connection_lost_abnormally_by_them2(self): - """ - Test handling of 2 connection lost events within 5 minutes of each other. - Now we can be more certain that the node is bad or doesn't want to talk to us. 
- - Expected behaviour: - - address in DEAD_ADDR list as it is unusable - - address not present in `KNOWN_ADDR` as it is unusable - - address not in connected peers list - - stopped all looping tasks of the node - """ - - def conn_setup(node: NeoNode): - # at this point we should have a fully connected node, so lets try to simulate a connection lost by the other side - with self.assertLogHandler('network', 10) as log: - # setup last_connection, to indicate we've lost connection before - node.address.last_connection = Address.Now() # returns a timestamp of utcnow() - - # now lose the connection - node.connectionLost(failure.Failure(error.ConnectionLost())) - - self.assertIn("second connection lost within 5 minutes", log.output[-1]) - self.assertIn(str(error.ConnectionLost()), log.output[-2]) - - self.assertIn(self.addr, self.leader.DEAD_ADDRS) - self.assertNotIn(self.addr, self.leader.KNOWN_ADDRS) - self.assertNotIn(self.addr, self.leader.Peers) - - self.assertFalse(node.has_tasks_running()) - - self.d.addCallback(conn_setup) - - return self.d - - def test_connection_lost_abnormally_by_them3(self): - """ - Test for a premature disconnect - - This means the other side closes connection before the heart_beat threshold exceeded - - Expected behaviour: - - address in DEAD_ADDR list as it is unusable - - address not present in `KNOWN_ADDR` as it is unusable - - address not in connected peers list - - stopped all looping tasks of the node - """ - - def conn_setup(node: NeoNode): - with self.assertLogHandler('network', 10) as log: - # setup last_connection, to indicate we've lost connection before - node.address.last_connection = Address.Now() # returns a timestamp of utcnow() - - # setup the heartbeat data to have last happened 25 seconds ago - # if we disconnect now we should get a premature disconnect - node.start_outstanding_data_request[HEARTBEAT_BLOCKS] = Address.Now() - 25 - - # now lose the connection - node.connectionLost(failure.Failure(error.ConnectionLost())) - - self.assertIn("Premature disconnect", log.output[-2]) - self.assertIn(str(error.ConnectionLost()), log.output[-1]) - - self.assertIn(self.addr, self.leader.DEAD_ADDRS) - self.assertNotIn(self.addr, self.leader.KNOWN_ADDRS) - self.assertNotIn(self.addr, self.leader.Peers) - - self.assertFalse(node.has_tasks_running()) - - self.d.addCallback(conn_setup) - - return self.d diff --git a/neo/Network/test_network1.py b/neo/Network/test_network1.py deleted file mode 100644 index 2b3677bd2..000000000 --- a/neo/Network/test_network1.py +++ /dev/null @@ -1,95 +0,0 @@ -""" -Test Nodeleader basics: starting and stopping - -""" - -from neo.Network.NodeLeader import NodeLeader -from neo.Network.address import Address -from twisted.trial import unittest as twisted_unittest -from twisted.internet import reactor as twisted_reactor -from twisted.internet import error -from mock import MagicMock - - -class NetworkBasicTest(twisted_unittest.TestCase): - def tearDown(self): - NodeLeader.Reset() - - def test_nodeleader_start_stop(self): - orig_connectTCP = twisted_reactor.connectTCP - twisted_reactor.connectTCP = MagicMock() - - seed_list = ['127.0.0.1:80', '127.0.0.2:80'] - leader = NodeLeader.Instance(reactor=twisted_reactor) - - leader.Start(seed_list=seed_list) - self.assertEqual(twisted_reactor.connectTCP.call_count, 2) - self.assertEqual(len(leader.KNOWN_ADDRS), 2) - - for seed, call in zip(seed_list, twisted_reactor.connectTCP.call_args_list): - host, port = seed.split(':') - arg = call[0] - - self.assertEqual(arg[0], host) - 
self.assertEqual(arg[1], int(port)) - - self.assertTrue(leader.peer_check_loop.running) - self.assertTrue(leader.blockheight_loop.running) - self.assertTrue(leader.memcheck_loop.running) - - leader.Shutdown() - - self.assertFalse(leader.peer_check_loop.running) - self.assertFalse(leader.blockheight_loop.running) - self.assertFalse(leader.memcheck_loop.running) - - # cleanup - twisted_reactor.connectTCP = orig_connectTCP - - def test_nodeleader_start_skip_seeds(self): - orig_connectTCP = twisted_reactor.connectTCP - twisted_reactor.connectTCP = MagicMock() - - seed_list = ['127.0.0.1:80', '127.0.0.2:80'] - leader = NodeLeader(reactor=twisted_reactor) - - leader.Start(seed_list=seed_list, skip_seeds=True) - - self.assertEqual(twisted_reactor.connectTCP.call_count, 0) - self.assertEqual(len(leader.KNOWN_ADDRS), 0) - - self.assertTrue(leader.peer_check_loop.running) - self.assertTrue(leader.blockheight_loop.running) - self.assertTrue(leader.memcheck_loop.running) - - leader.Shutdown() - - # cleanup - twisted_reactor.connectTCP = orig_connectTCP - - def test_connection_refused(self): - """Test handling of a bad address. Where bad could be a dead or unreachable endpoint - - Expected behaviour: - - add address to DEAD_ADDR list as it's unusable - - remove address from KNOWN_ADDR list as it's unusable - """ - leader = NodeLeader.Instance() - - PORT_WITH_NO_SERVICE = 12312 - addr = Address("127.0.0.1:" + str(PORT_WITH_NO_SERVICE)) - - # normally this is done by NodeLeader.Start(), now we add the address manually so we can verify it's removed properly - leader.KNOWN_ADDRS.append(addr) - - def connection_result(value): - self.assertEqual(error.ConnectionRefusedError, value) - self.assertIn(addr, leader.DEAD_ADDRS) - self.assertNotIn(addr, leader.KNOWN_ADDRS) - - d = leader.SetupConnection(addr) # type: Deferred - # leader.clientConnectionFailed() does not rethrow the Failure, therefore we should get the result via the callback, not errback. - # adding both for simplicity. The test will fail on the first assert if the connection was successful. 
- d.addBoth(connection_result) - - return d diff --git a/neo/Network/test_node.py b/neo/Network/test_node.py deleted file mode 100644 index 455ec475e..000000000 --- a/neo/Network/test_node.py +++ /dev/null @@ -1,110 +0,0 @@ -from unittest import TestCase -from twisted.trial import unittest as twisted_unittest -from neo.Network.NeoNode import NeoNode -from mock import patch -from neo.Network.Payloads.VersionPayload import VersionPayload -from neo.Network.Message import Message -from neo.IO.MemoryStream import StreamManager -from neo.Core.IO.BinaryWriter import BinaryWriter -from neo.Network.NodeLeader import NodeLeader -from twisted.test import proto_helpers - -import sys - - -class Endpoint: - def __init__(self, host, port): - self.host = host - self.port = port - - -# class NodeNetworkingTestCase(twisted_unittest.TestCase): -# def setUp(self): -# factory = NeoClientFactory() -# self.proto = factory.buildProtocol(('127.0.0.1', 0)) -# self.tr = proto_helpers.StringTransport() -# self.proto.makeConnection(self.tr) -# -# def test_max_recursion_on_datareceived(self): -# """ -# TDD: if the data buffer receives network data faster than it can clear it then eventually -# `CheckDataReceived()` in `NeoNode` exceeded the max recursion depth -# """ -# old_limit = sys.getrecursionlimit() -# raw_message = b"\xb1\xdd\x00\x00version\x00\x00\x00\x00\x00'\x00\x00\x00a\xbb\x9av\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x0ef\x9e[mO3\xe7q\x08\x0b/NEO:2.7.4/=\x8b\x00\x00\x01" -# -# sys.setrecursionlimit(100) -# # we fill the buffer with 102 packets, which exceeds the 100 recursion depth limit -# self.proto.dataReceived(raw_message * 102) -# # no need to assert anything. If the bug still exists then we get a Python core dump and the process will stop automatically -# # otherwise restore old limit -# sys.setrecursionlimit(old_limit) -# -# def tearDown(self): -# leader = NodeLeader.Instance() -# leader.Peers = [] -# leader.KNOWN_ADDRS = [] - - -class NodeTestCase(TestCase): - - @patch.object(NeoNode, 'MessageReceived') - def test_handle_message(self, mock): - node = NeoNode() - node.endpoint = Endpoint('hello.com', 1234) - node.host = node.endpoint.host - node.port = node.endpoint.port - - payload = VersionPayload(10234, 1234, 'version') - - message = Message('version', payload=payload) - - stream = StreamManager.GetStream() - writer = BinaryWriter(stream) - - message.Serialize(writer) - - out = stream.getvalue() - - print("OUT %s " % out) - - out1 = out[0:10] - out2 = out[10:20] - out3 = out[20:] - - node.dataReceived(out1) - node.dataReceived(out2) - - self.assertEqual(node.buffer_in, out1 + out2) - # import pdb - # pdb.set_trace() - - self.assertEqual(node.bytes_in, 20) - - mock.assert_not_called() - - node.dataReceived(out3) - - self.assertEqual(node.bytes_in, len(out)) - # mock.assert_called_with(message) - - mock.assert_called_once() - - @patch.object(NeoNode, 'SendVersion') - def test_data_received(self, mock): - node = NeoNode() - node.endpoint = Endpoint('hello.com', 1234) - node.host = node.endpoint.host - node.port = node.endpoint.port - payload = VersionPayload(10234, 1234, 'version') - message = Message('version', payload=payload) - stream = StreamManager.GetStream() - writer = BinaryWriter(stream) - message.Serialize(writer) - - out = stream.getvalue() - node.dataReceived(out) - - mock.assert_called_once() - - self.assertEqual(node.Version.Nonce, payload.Nonce) diff --git a/neo/Network/test_node_leader.py b/neo/Network/test_node_leader.py deleted file mode 100644 index b51235580..000000000 --- 
a/neo/Network/test_node_leader.py +++ /dev/null @@ -1,306 +0,0 @@ -from neo.Utils.WalletFixtureTestCase import WalletFixtureTestCase -from neo.Network.NodeLeader import NodeLeader -from neo.Network.NeoNode import NeoNode -from mock import patch -from neo.Settings import settings -from neo.Core.Blockchain import Blockchain -from neo.Core.UInt160 import UInt160 -from neo.Core.Fixed8 import Fixed8 -from neo.Implementations.Wallets.peewee.UserWallet import UserWallet -from neo.Wallets.utils import to_aes_key -from neo.SmartContract.ContractParameterContext import ContractParametersContext -from neo.Core.TX.Transaction import ContractTransaction, TransactionOutput, TXFeeError -from neo.Core.TX.MinerTransaction import MinerTransaction -from twisted.trial import unittest as twisted_unittest -from twisted.test import proto_helpers -from twisted.internet.address import IPv4Address -from twisted.internet import task -from mock import MagicMock, patch -from neo.api.JSONRPC.JsonRpcApi import JsonRpcApi -from neo.Network.address import Address -from unittest import skip - - -class Endpoint: - def __init__(self, host, port): - self.host = host - self.port = port - - # class NodeLeaderConnectionTest(twisted_unittest.TestCase): - # - # @classmethod - # def setUpClass(cls): - # # clean up left over of other tests classes - # leader = NodeLeader.Instance() - # leader.Peers = [] - # leader.KNOWN_ADDRS = [] - # - # def _add_new_node(self, host, port): - # self.tr.getPeer.side_effect = [IPv4Address('TCP', host, port)] - # node = self.factory.buildProtocol(('127.0.0.1', 0)) - # node.makeConnection(self.tr) # makeConnection also assigns tr to node.transport - # - # return node - # - # def setUp(self): - # self.factory = NeoClientFactory() - # self.tr = proto_helpers.StringTransport() - # self.tr.getPeer = MagicMock() - # self.leader = NodeLeader.Instance() - # - # def test_getpeer_list_vs_maxpeer_list(self): - # """https://github.com/CityOfZion/neo-python/issues/678""" - # settings.set_max_peers(1) - # api_server = JsonRpcApi(None, None) - # # test we start with a clean state - # peers = api_server.get_peers() - # self.assertEqual(len(peers['connected']), 0) - # - # # try connecting more nodes than allowed by the max peers settings - # first_node = self._add_new_node('127.0.0.1', 1111) - # second_node = self._add_new_node('127.0.0.2', 2222) - # peers = api_server.get_peers() - # # should respect max peer setting - # self.assertEqual(1, len(peers['connected'])) - # self.assertEqual('127.0.0.1', peers['connected'][0]['address']) - # self.assertEqual(1111, peers['connected'][0]['port']) - # - # # now drop the existing node - # self.factory.clientConnectionLost(first_node, reason="unittest") - # # add a new one - # second_node = self._add_new_node('127.0.0.2', 2222) - # # and test if `first_node` we dropped can pass limit checks when it reconnects - # self.leader.PeerCheckLoop() - # peers = api_server.get_peers() - # self.assertEqual(1, len(peers['connected'])) - # self.assertEqual('127.0.0.2', peers['connected'][0]['address']) - # self.assertEqual(2222, peers['connected'][0]['port']) - # - # # restore default settings - # settings.set_max_peers(5) - - -class LeaderTestCase(WalletFixtureTestCase): - wallet_1_script_hash = UInt160(data=b'\x1c\xc9\xc0\\\xef\xff\xe6\xcd\xd7\xb1\x82\x81j\x91R\xec!\x8d.\xc0') - - wallet_1_addr = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3' - - import_watch_addr = UInt160(data=b'\x08t/\\P5\xac-\x0b\x1c\xb4\x94tIyBu\x7f1*') - watch_addr_str = 'AGYaEi3W6ndHPUmW7T12FFfsbQ6DWymkEm' - _wallet1 = None 
- - @classmethod - def GetWallet1(cls, recreate=False): - if cls._wallet1 is None or recreate: - cls._wallet1 = UserWallet.Open(LeaderTestCase.wallet_1_dest(), to_aes_key(LeaderTestCase.wallet_1_pass())) - return cls._wallet1 - - @classmethod - def tearDown(cls): - NodeLeader.Instance().Peers = [] - NodeLeader.__LEAD = None - - def test_initialize(self): - leader = NodeLeader.Instance() - self.assertEqual(leader.Peers, []) - self.assertEqual(leader.KNOWN_ADDRS, []) - - # - # @skip("to be updated once new network code is approved") - # def test_peer_adding(self): - # leader = NodeLeader.Instance() - # Blockchain.Default()._block_cache = {'hello': 1} - # - # def mock_call_later(delay, method, *args): - # method(*args) - # - # def mock_connect_tcp(host, port, factory, timeout=120): - # node = NeoNode() - # node.endpoint = Endpoint(host, port) - # leader.AddConnectedPeer(node) - # return node - # - # def mock_disconnect(peer): - # return True - # - # def mock_send_msg(node, message): - # return True - # - # settings.set_max_peers(len(settings.SEED_LIST)) - # - # with patch('twisted.internet.reactor.connectTCP', mock_connect_tcp): - # with patch('twisted.internet.reactor.callLater', mock_call_later): - # with patch('neo.Network.NeoNode.NeoNode.Disconnect', mock_disconnect): - # with patch('neo.Network.NeoNode.NeoNode.SendSerializedMessage', mock_send_msg): - # leader.Start() - # self.assertEqual(len(leader.Peers), len(settings.SEED_LIST)) - # - # # now test adding another - # leader.RemoteNodePeerReceived('hello.com', 1234, 6) - # - # # it shouldnt add anything so it doesnt go over max connected peers - # self.assertEqual(len(leader.Peers), len(settings.SEED_LIST)) - # - # # test adding peer - # peer = NeoNode() - # peer.endpoint = Endpoint('hellloo.com', 12344) - # leader.KNOWN_ADDRS.append(Address('hellloo.com:12344')) - # leader.AddConnectedPeer(peer) - # self.assertEqual(len(leader.Peers), len(settings.SEED_LIST)) - # - # # now get a peer - # peer = leader.Peers[0] - # - # leader.RemoveConnectedPeer(peer) - # - # # the connect peers should be 1 less than the seed_list - # self.assertEqual(len(leader.Peers), len(settings.SEED_LIST) - 1) - # # the known addresses should be equal the seed_list - # self.assertEqual(len(leader.KNOWN_ADDRS), len(settings.SEED_LIST)) - # - # # now test adding another - # leader.RemoteNodePeerReceived('hello.com', 1234, 6) - # - # self.assertEqual(len(leader.Peers), len(settings.SEED_LIST)) - # - # # now if we remove all peers, it should restart - # peers = leader.Peers[:] - # for peer in peers: - # leader.RemoveConnectedPeer(peer) - # - # # test reset - # # leader.ResetBlockRequestsAndCache() - # # self.assertEqual(Blockchain.Default()._block_cache, {}) - # - # # test shutdown - # leader.Shutdown() - # self.assertEqual(len(leader.Peers), 0) - # - def _generate_tx(self, amount): - wallet = self.GetWallet1() - - output = TransactionOutput(AssetId=Blockchain.SystemShare().Hash, Value=amount, - script_hash=LeaderTestCase.wallet_1_script_hash) - contract_tx = ContractTransaction(outputs=[output]) - try: - wallet.MakeTransaction(contract_tx) - except (ValueError, TXFeeError): - pass - ctx = ContractParametersContext(contract_tx) - wallet.Sign(ctx) - contract_tx.scripts = ctx.GetScripts() - return contract_tx - - # @skip("to be updated once new network code is approved") - # def test_relay(self): - # leader = NodeLeader.Instance() - # - # def mock_call_later(delay, method, *args): - # method(*args) - # - # def mock_connect_tcp(host, port, factory, timeout=120): - # node 
= NeoNode() - # node.endpoint = Endpoint(host, port) - # leader.AddConnectedPeer(node) - # return node - # - # def mock_send_msg(node, message): - # return True - # - # with patch('twisted.internet.reactor.connectTCP', mock_connect_tcp): - # with patch('twisted.internet.reactor.callLater', mock_call_later): - # with patch('neo.Network.NeoNode.NeoNode.SendSerializedMessage', mock_send_msg): - # leader.Start() - # - # miner = MinerTransaction() - # - # res = leader.Relay(miner) - # self.assertFalse(res) - # - # tx = self._generate_tx(Fixed8.One()) - # - # res = leader.Relay(tx) - # self.assertEqual(res, True) - # - # self.assertTrue(tx.Hash.ToBytes() in leader.MemPool.keys()) - # res2 = leader.Relay(tx) - # self.assertFalse(res2) - # - def test_inventory_received(self): - - leader = NodeLeader.Instance() - - miner = MinerTransaction() - miner.Nonce = 1234 - res = leader.InventoryReceived(miner) - - self.assertFalse(res) - - block = Blockchain.Default().GenesisBlock() - - res2 = leader.InventoryReceived(block) - - self.assertFalse(res2) - - tx = self._generate_tx(Fixed8.TryParse(15)) - - res = leader.InventoryReceived(tx) - - self.assertIsNone(res) - - def _add_existing_tx(self): - wallet = self.GetWallet1() - - existing_tx = None - for tx in wallet.GetTransactions(): - existing_tx = tx - break - - self.assertNotEqual(None, existing_tx) - - # add the existing tx to the mempool - NodeLeader.Instance().MemPool[tx.Hash.ToBytes()] = tx - - def _clear_mempool(self): - txs = [] - values = NodeLeader.Instance().MemPool.values() - for tx in values: - txs.append(tx) - - for tx in txs: - del NodeLeader.Instance().MemPool[tx.Hash.ToBytes()] - - def test_get_transaction(self): - # delete any tx in the mempool - self._clear_mempool() - - # generate a new tx - tx = self._generate_tx(Fixed8.TryParse(5)) - - # try to get it - res = NodeLeader.Instance().GetTransaction(tx.Hash.ToBytes()) - self.assertIsNone(res) - - # now add it to the mempool - NodeLeader.Instance().MemPool[tx.Hash.ToBytes()] = tx - - # and try to get it - res = NodeLeader.Instance().GetTransaction(tx.Hash.ToBytes()) - self.assertTrue(res is tx) - - def test_mempool_check_loop(self): - # delete any tx in the mempool - self._clear_mempool() - - # add a tx which is already confirmed - self._add_existing_tx() - - # and add a tx which is not confirmed - tx = self._generate_tx(Fixed8.TryParse(20)) - NodeLeader.Instance().MemPool[tx.Hash.ToBytes()] = tx - - # now remove the confirmed tx - NodeLeader.Instance().MempoolCheck() - - self.assertEqual( - len(list(map(lambda hash: "0x%s" % hash.decode('utf-8'), NodeLeader.Instance().MemPool.keys()))), 1) diff --git a/neo/Network/tests/__init__.py b/neo/Network/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/neo/Network/tests/test_ipfilter.py b/neo/Network/tests/test_ipfilter.py new file mode 100644 index 000000000..1334083b0 --- /dev/null +++ b/neo/Network/tests/test_ipfilter.py @@ -0,0 +1,143 @@ +import unittest +from neo.Network.ipfilter import IPFilter + + +class IPFilteringTestCase(unittest.TestCase): + def test_nobody_allowed(self): + filter = IPFilter() + filter.config = { + 'blacklist': [ + '0.0.0.0/0' + ], + 'whitelist': [ + ] + } + + self.assertFalse(filter.is_allowed('127.0.0.1')) + self.assertFalse(filter.is_allowed('10.10.10.10')) + + def test_nobody_allowed_except_one(self): + filter = IPFilter() + filter.config = { + 'blacklist': [ + '0.0.0.0/0' + ], + 'whitelist': [ + '10.10.10.10' + ] + } + + self.assertFalse(filter.is_allowed('127.0.0.1')) + 
self.assertFalse(filter.is_allowed('10.10.10.11')) + self.assertTrue(filter.is_allowed('10.10.10.10')) + + def test_everybody_allowed(self): + filter = IPFilter() + filter.config = { + 'blacklist': [ + ], + 'whitelist': [ + ] + } + + self.assertTrue(filter.is_allowed('127.0.0.1')) + self.assertTrue(filter.is_allowed('10.10.10.11')) + self.assertTrue(filter.is_allowed('10.10.10.10')) + + filter.config = { + 'blacklist': [ + ], + 'whitelist': [ + '0.0.0.0/0' + ] + } + + self.assertTrue(filter.is_allowed('127.0.0.1')) + self.assertTrue(filter.is_allowed('10.10.10.11')) + self.assertTrue(filter.is_allowed('10.10.10.10')) + + filter.config = { + 'blacklist': [ + '0.0.0.0/0' + ], + 'whitelist': [ + '0.0.0.0/0' + ] + } + + self.assertTrue(filter.is_allowed('127.0.0.1')) + self.assertTrue(filter.is_allowed('10.10.10.11')) + self.assertTrue(filter.is_allowed('10.10.10.10')) + + def test_everybody_allowed_except_one(self): + filter = IPFilter() + filter.config = { + 'blacklist': [ + '127.0.0.1' + ], + 'whitelist': [ + ] + } + + self.assertFalse(filter.is_allowed('127.0.0.1')) + self.assertTrue(filter.is_allowed('10.10.10.11')) + self.assertTrue(filter.is_allowed('10.10.10.10')) + + def test_disallow_ip_range(self): + filter = IPFilter() + filter.config = { + 'blacklist': [ + '127.0.0.0/24' + ], + 'whitelist': [ + ] + } + + self.assertFalse(filter.is_allowed('127.0.0.0')) + self.assertFalse(filter.is_allowed('127.0.0.1')) + self.assertFalse(filter.is_allowed('127.0.0.100')) + self.assertFalse(filter.is_allowed('127.0.0.255')) + self.assertTrue(filter.is_allowed('10.10.10.11')) + self.assertTrue(filter.is_allowed('10.10.10.10')) + + def test_updating_blacklist(self): + filter = IPFilter() + filter.config = { + 'blacklist': [ + ], + 'whitelist': [ + ] + } + + self.assertTrue(filter.is_allowed('127.0.0.1')) + + filter.blacklist_add('127.0.0.0/24') + self.assertFalse(filter.is_allowed('127.0.0.1')) + # should have no effect, only exact matches + filter.blacklist_remove('127.0.0.1') + self.assertFalse(filter.is_allowed('127.0.0.1')) + + filter.blacklist_remove('127.0.0.0/24') + self.assertTrue(filter.is_allowed('127.0.0.1')) + + def test_updating_whitelist(self): + filter = IPFilter() + filter.config = { + 'blacklist': [ + '0.0.0.0/0' + ], + 'whitelist': [ + ] + } + + self.assertFalse(filter.is_allowed('127.0.0.1')) + + filter.whitelist_add('127.0.0.0/24') + self.assertTrue(filter.is_allowed('127.0.0.1')) + + filter.whitelist_remove('127.0.0.1') + # should have no effect, only exact matches + self.assertTrue(filter.is_allowed('127.0.0.1')) + + filter.whitelist_remove('127.0.0.0/24') + self.assertFalse(filter.is_allowed('127.0.0.1')) diff --git a/neo/Network/tests/test_syncmanager1.py b/neo/Network/tests/test_syncmanager1.py new file mode 100644 index 000000000..6e5165090 --- /dev/null +++ b/neo/Network/tests/test_syncmanager1.py @@ -0,0 +1,55 @@ +import asynctest +import asyncio +from neo.Network.syncmanager import SyncManager + + +class SyncManagerTestCase(asynctest.TestCase): + async def test_start(self): + syncmgr = SyncManager(asynctest.MagicMock) + syncmgr.reset() + + # mock values + syncmgr.nodemgr.running = True + syncmgr.run_service = asynctest.CoroutineMock() + syncmgr.block_health = asynctest.CoroutineMock() + + # run + await syncmgr.start() + self.assertNotEqual(syncmgr.service_task, None) + self.assertNotEqual(syncmgr.health_task, None) + + +class ShutdownSyncManagerTests(asynctest.TestCase): + def setUp(self) -> None: + self.syncmgr = SyncManager(asynctest.MagicMock) + self.syncmgr.reset() + 
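# Roughly the lifecycle these shutdown tests exercise (the node manager and ledger are
# mocked here; in real use they are assumed to be wired up elsewhere before starting):
#
#     syncmgr = SyncManager(nodemgr)   # Singleton; the argument ends up in init()
#     await syncmgr.start()            # spawns the run_service() and block_health() tasks
#     ...
#     await syncmgr.shutdown()         # stops the loops, cancels block_health() and awaits the rest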
+ # mock values + self.syncmgr.nodemgr.running = True + # self.syncmgr.run_service = asynctest.CoroutineMock() + self.syncmgr.block_health = asynctest.CoroutineMock() + + async def test_normal_shutdown(self): + # ensure the coroutine throws a CancelledError as that's the normal response to .cancel() + # this makes it testable + self.syncmgr.block_health.side_effect = asyncio.CancelledError() + + # run + await self.syncmgr.start() + await self.syncmgr.shutdown() + + self.assertTrue(self.syncmgr.health_task.cancelled()) + + async def test_shutdown_with_service_exception(self): + self.syncmgr.raise_exception = False + self.syncmgr.block_health.side_effect = asyncio.CancelledError() + self.syncmgr.check_timeout = asynctest.CoroutineMock() + self.syncmgr.sync = asynctest.CoroutineMock() + + await self.syncmgr.start() + await asyncio.sleep(0.5) + self.syncmgr.sync.side_effect = Exception() + self.syncmgr.raise_exception = True + await asyncio.sleep(0.5) + await self.syncmgr.shutdown() + self.assertTrue(self.syncmgr.health_task.cancelled()) diff --git a/neo/Network/tests/test_syncmanager2.py b/neo/Network/tests/test_syncmanager2.py new file mode 100644 index 000000000..e2065bf37 --- /dev/null +++ b/neo/Network/tests/test_syncmanager2.py @@ -0,0 +1,168 @@ +import asynctest +from logging import DEBUG +from neo.Network.syncmanager import SyncManager +from neo.Utils.NeoTestCase import NeoTestCase +from neo.Network.flightinfo import FlightInfo +from neo.Network.requestinfo import RequestInfo +from neo.Network.nodemanager import NeoNode + + +class TimeoutSyncMgrTestCase(NeoTestCase, asynctest.TestCase): + def setUp(self) -> None: + # we have to override the singleton behaviour or our coroutine mocks will persist + with asynctest.patch('neo.Network.syncmanager.SyncManager.__new__', return_value=object.__new__(SyncManager)): + self.syncmgr = SyncManager() + self.syncmgr.init(asynctest.MagicMock) + self.syncmgr.reset() + + async def test_header_exception(self): + self.syncmgr.check_block_timeout = asynctest.CoroutineMock() + + self.syncmgr.check_header_timeout = asynctest.CoroutineMock() + self.syncmgr.check_header_timeout.side_effect = Exception("unittest exception") + + with self.assertLogHandler('syncmanager', DEBUG) as context: + await self.syncmgr.check_timeout() + self.assertGreater(len(context.output), 0) + self.assertTrue("unittest exception" in context.output[0]) + + async def test_block_exception(self): + self.syncmgr.check_block_timeout = asynctest.CoroutineMock() + self.syncmgr.check_block_timeout.side_effect = Exception("unittest exception") + + self.syncmgr.check_header_timeout = asynctest.CoroutineMock() + + with self.assertLogHandler('syncmanager', DEBUG) as context: + await self.syncmgr.check_timeout() + self.assertGreater(len(context.output), 0) + self.assertTrue("unittest exception" in context.output[0]) + + async def test_no_outstanding_header_requests(self): + # should return immediately + self.syncmgr.header_request = asynctest.MagicMock() + self.syncmgr.header_request.__bool__.return_value = False + await self.syncmgr.check_timeout() + + self.assertFalse(self.syncmgr.header_request.most_recent_flight.called) + + async def test_outstanding_header_request_within_boundaries(self): + # should return early because we have not exceeded the threshold + cur_header_height = 1 + node_id = 123 + + self.syncmgr.nodemgr = asynctest.MagicMock() + self.syncmgr.header_request = RequestInfo(cur_header_height + 1) + self.syncmgr.header_request.add_new_flight(FlightInfo(node_id, cur_header_height + 
1)) + + await self.syncmgr.check_timeout() + + self.assertFalse(self.syncmgr.nodemgr.get_node_by_nodeid.called) + + async def test_outstanding_header_request_timedout_but_received(self): + """ + test an outstanding request that timed out, but has been received in the meantime + """ + cur_header_height = 1 + node_id = 123 + + # mock node manager state + self.syncmgr.nodemgr = asynctest.MagicMock() + node1 = NeoNode(object(), object()) + node1.nodeid = node_id + self.syncmgr.nodemgr.get_node_by_id.return_value = node1 + + # mock ledger state + self.syncmgr.ledger = asynctest.MagicMock() + # we pretend our local ledger has a height higher than what we just asked for + self.syncmgr.ledger.cur_header_height = asynctest.CoroutineMock(return_value=3) + + # setup sync manager state to have an outstanding header request + self.syncmgr.header_request = RequestInfo(cur_header_height + 1) + fi = FlightInfo(node_id, cur_header_height + 1) + fi.start_time = fi.start_time - 5 # decrease start time by 5 seconds to exceed timeout threshold + self.syncmgr.header_request.add_new_flight(fi) + + with self.assertLogHandler('syncmanager', DEBUG) as log_context: + await self.syncmgr.check_timeout() + self.assertGreater(len(log_context.output), 0) + self.assertTrue("Header timeout limit exceed" in log_context.output[0]) + + self.assertIsNone(self.syncmgr.header_request) + + async def test_outstanding_header_request_timedout(self): + """ + test an outstanding request that timed out, but for which we cannot ask another node + conditions: + - current node exceeded MAX_TIMEOUT_COUNT and will be disconnected + - no other nodes are connected that have our desired height + + Expected: + We expect to return without setting up another header request with a new node. The node manager + should resolve getting new nodes + """ + cur_header_height = 1 + node_id = 123 + + # mock node manager state + self.syncmgr.nodemgr = asynctest.MagicMock() + node1 = NeoNode(object(), object()) + node1.nodeid = node_id + self.syncmgr.nodemgr.get_node_by_id.return_value = node1 + # returning None indicates we have no more nodes connected with our desired height + self.syncmgr.nodemgr.get_node_with_min_failed_time.return_value = None + self.syncmgr.nodemgr.add_node_timeout_count = asynctest.CoroutineMock() + # ------- + + # mock ledger state + self.syncmgr.ledger = asynctest.MagicMock() + # we pretend our local ledger has a height higher than what we just asked for + self.syncmgr.ledger.cur_header_height = asynctest.CoroutineMock(return_value=1) + # ------ + + # setup sync manager state to have an outstanding header request + self.syncmgr.header_request = RequestInfo(cur_header_height + 1) + fi = FlightInfo(node_id, cur_header_height + 1) + fi.start_time = fi.start_time - 5 # decrease start time by 5 seconds to exceed timeout threshold + self.syncmgr.header_request.add_new_flight(fi) + + with self.assertLogHandler('syncmanager', DEBUG) as log_context: + await self.syncmgr.check_timeout() + self.assertGreater(len(log_context.output), 0) + self.assertTrue("Header timeout limit exceed" in log_context.output[0]) + + self.assertTrue(self.syncmgr.nodemgr.get_node_with_min_failed_time.called) + self.assertIsNone(self.syncmgr.header_request) + + async def test_outstanding_request_timedout(self): + cur_header_height = 1 + node_id = 123 + + # mock node manager state + self.syncmgr.nodemgr = asynctest.MagicMock() + node1 = NeoNode(object(), object()) + node1.nodeid = node_id + self.syncmgr.nodemgr.get_node_by_id.return_value = node1 + + node2 = 
asynctest.MagicMock() # NeoNode(object(), object()) + node2.nodeid.return_value = 456 + node2.get_headers = asynctest.CoroutineMock() + self.syncmgr.nodemgr.get_node_with_min_failed_time.return_value = node2 + self.syncmgr.nodemgr.add_node_timeout_count = asynctest.CoroutineMock() + + # mock ledger state + self.syncmgr.ledger = asynctest.MagicMock() + # we pretend our local ledger has a height higher than what we just asked for + self.syncmgr.ledger.cur_header_height = asynctest.CoroutineMock(return_value=1) + self.syncmgr.ledger.header_hash_by_height = asynctest.CoroutineMock(return_value=b'') + # ------ + + # setup sync manager state to have an outstanding header request + self.syncmgr.header_request = RequestInfo(cur_header_height + 1) + fi = FlightInfo(node_id, cur_header_height + 1) + fi.start_time = fi.start_time - 5 # decrease start time by 5 seconds to exceed timeout threshold + self.syncmgr.header_request.add_new_flight(fi) + + with self.assertLogHandler('syncmanager', DEBUG) as log_context: + await self.syncmgr.check_timeout() + self.assertGreater(len(log_context.output), 0) + self.assertTrue("Retry requesting headers starting at 2" in log_context.output[-1]) diff --git a/neo/Network/tests/test_syncmanager3.py b/neo/Network/tests/test_syncmanager3.py new file mode 100644 index 000000000..0c9b32fa6 --- /dev/null +++ b/neo/Network/tests/test_syncmanager3.py @@ -0,0 +1,130 @@ +import asynctest +import asyncio +import os +from logging import DEBUG +from neo.Network.syncmanager import SyncManager +from neo.Network.flightinfo import FlightInfo +from neo.Network.requestinfo import RequestInfo +from neo.Network.core.header import Header +from neo.Network.ledger import Ledger +from neo.Utils.BlockchainFixtureTestCase import BlockchainFixtureTestCase +from neo.Settings import settings +from neo.Network.core.uint256 import UInt256 +from neo.Network.core.uint160 import UInt160 + + +class HeadersReceivedSyncMgrTestCase(BlockchainFixtureTestCase, asynctest.TestCase): + @classmethod + def leveldb_testpath(self): + return os.path.join(settings.DATA_DIR_PATH, 'fixtures/test_chain') + + def setUp(self) -> None: + # we have to override the singleton behaviour or our coroutine mocks will persist + with asynctest.patch('neo.Network.syncmanager.SyncManager.__new__', return_value=object.__new__(SyncManager)): + self.syncmgr = SyncManager() + self.syncmgr.init(asynctest.MagicMock) + self.syncmgr.reset() + + async def test_empty_header_list(self): + res = await self.syncmgr.on_headers_received(123, []) + self.assertEqual(res, -1) + + async def test_unexpected_headers_received(self): + # headers received while we have no outstanding request should be ignored early + + self.syncmgr.header_request = None + res = await self.syncmgr.on_headers_received(123, [object()]) + self.assertEqual(res, -2) + + async def test_headers_received_not_matching_requested_height(self): + cur_header_height = 1 + node_id = 123 + + self.syncmgr.header_request = RequestInfo(cur_header_height + 1) + self.syncmgr.header_request.add_new_flight(FlightInfo(node_id, cur_header_height + 1)) + + height = 123123 + header = Header(object(), object(), 0, height, object(), object(), object()) + res = await self.syncmgr.on_headers_received(123, [header]) + self.assertEqual(res, -3) + + async def test_headers_received_outdated_height(self): + # test that a slow response that has been superseded by a fast response + # from another node does not get processed twice + cur_header_height = 1 + node_id = 123 + + self.syncmgr.header_request = 
RequestInfo(cur_header_height + 1) + self.syncmgr.header_request.add_new_flight(FlightInfo(node_id, cur_header_height + 1)) + + height = 2 + header = Header(object(), object(), 0, height, object(), object(), object()) + + # mock ledger state + self.syncmgr.ledger = asynctest.MagicMock() + self.syncmgr.ledger.cur_header_height = asynctest.CoroutineMock(return_value=2) + + with self.assertLogHandler('syncmanager', DEBUG) as log_context: + res = await self.syncmgr.on_headers_received(123, [header]) + self.assertEqual(res, -5) + self.assertGreater(len(log_context.output), 0) + self.assertTrue("Headers received 2 - 2" in log_context.output[0]) + + +class HeadersReceivedSyncMgrTestCase2(BlockchainFixtureTestCase, asynctest.TestCase): + """ + For the final test we need to use a new fixture + """ + + @classmethod + def leveldb_testpath(self): + return os.path.join(settings.DATA_DIR_PATH, 'fixtures/test_chain') + + def setUp(self) -> None: + # we have to override the singleton behaviour or our coroutine mocks will persist + with asynctest.patch('neo.Network.syncmanager.SyncManager.__new__', return_value=object.__new__(SyncManager)): + self.syncmgr = SyncManager() + self.syncmgr.init(asynctest.MagicMock) + self.syncmgr.reset() + + def test_simultaneous_same_header_received(self): + """ + test ensures that we do not waste computing resources processing the same headers multiple times + expected result is 1 "processed" event (return value 1) and 4 early exit events (return value -4) + """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self.syncmgr.ledger = Ledger() + self.syncmgr.nodemgr.add_node_error_count = asynctest.CoroutineMock() + + height = 12357 + node_id = 123 + + self.syncmgr.header_request = RequestInfo(height) + self.syncmgr.header_request.add_new_flight(FlightInfo(node_id, height)) + + fake_uint256 = UInt256(data=bytearray(32)) + fake_uint160 = UInt160(data=bytearray(20)) + not_used = object() + + # create 2000 headers that can be persisted + headers = [] + for i in range(2000): + headers.append(Header(fake_uint256, fake_uint256, 0, height + i, 0, fake_uint160, not_used)) + + # create 5 tasks to schedule incoming headers + tasks = [] + for i in range(5): + tasks.append(loop.create_task(self.syncmgr.on_headers_received(i, headers))) + + # run all tasks + try: + results = loop.run_until_complete(asyncio.gather(*tasks)) + finally: + loop.close() + + # assert that only the first one gets fully processed, the rest not + success = 1 + already_exist = -4 + expected_results = [success, already_exist, already_exist, already_exist, already_exist] + self.assertEqual(results, expected_results) diff --git a/neo/Network/utils.py b/neo/Network/utils.py new file mode 100644 index 000000000..86be7723c --- /dev/null +++ b/neo/Network/utils.py @@ -0,0 +1,24 @@ +import socket +import ipaddress + + +def hostname_to_ip(hostname: str) -> str: + """ + Args: + hostname: e.g. seed1.ngd.network + + Raises: + socket.gaierror if hostname could not be resolved + Returns: host e.g. 
10.1.1.3 + + """ + return socket.gethostbyname(hostname) + + +def is_ip_address(hostname: str) -> bool: + host = hostname.split(':')[0] + try: + ip = ipaddress.ip_address(host) + return True + except ValueError: + return False diff --git a/neo/Prompt/Commands/Bootstrap.py b/neo/Prompt/Commands/Bootstrap.py index 726c5b095..bc89dabb2 100644 --- a/neo/Prompt/Commands/Bootstrap.py +++ b/neo/Prompt/Commands/Bootstrap.py @@ -1,6 +1,6 @@ import sys from neo.Settings import settings -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt import requests from tqdm import tqdm import tarfile diff --git a/neo/Prompt/Commands/BuildNRun.py b/neo/Prompt/Commands/BuildNRun.py index a4aaeeb85..5b6f5f73e 100644 --- a/neo/Prompt/Commands/BuildNRun.py +++ b/neo/Prompt/Commands/BuildNRun.py @@ -52,7 +52,7 @@ def Build(arguments): return contract_script -def BuildAndRun(arguments, wallet, verbose=True, min_fee=DEFAULT_MIN_FEE, invocation_test_mode=True): +def BuildAndRun(arguments, wallet, verbose=True, min_fee=DEFAULT_MIN_FEE, invocation_test_mode=True, enable_debugger=False): arguments, from_addr = get_from_addr(arguments) arguments, invoke_attrs = get_tx_attr_from_args(arguments) arguments, owners = get_owners_from_params(arguments) @@ -69,7 +69,7 @@ def BuildAndRun(arguments, wallet, verbose=True, min_fee=DEFAULT_MIN_FEE, invoca return DoRun(contract_script, arguments, wallet, path, verbose, from_addr, min_fee, invocation_test_mode, - debug_map=debug_map, invoke_attrs=invoke_attrs, owners=owners) + debug_map=debug_map, invoke_attrs=invoke_attrs, owners=owners, enable_debugger=enable_debugger) else: print('Please check the path to your Python (.py) file to compile') return None, None, None, None @@ -77,7 +77,7 @@ def BuildAndRun(arguments, wallet, verbose=True, min_fee=DEFAULT_MIN_FEE, invoca def DoRun(contract_script, arguments, wallet, path, verbose=True, from_addr=None, min_fee=DEFAULT_MIN_FEE, invocation_test_mode=True, - debug_map=None, invoke_attrs=None, owners=None): + debug_map=None, invoke_attrs=None, owners=None, enable_debugger=False): if not wallet: print("Please open a wallet to test build contract") return None, None, None, None @@ -92,7 +92,7 @@ def DoRun(contract_script, arguments, wallet, path, verbose=True, tx, result, total_ops, engine = test_deploy_and_invoke(script, i_args, wallet, from_addr, min_fee, invocation_test_mode, debug_map=debug_map, - invoke_attrs=invoke_attrs, owners=owners) + invoke_attrs=invoke_attrs, owners=owners, enable_debugger=enable_debugger) i_args.reverse() return_type_results = [] @@ -104,7 +104,7 @@ def DoRun(contract_script, arguments, wallet, path, verbose=True, except Exception: raise TypeError - if tx and result: + if tx is not None and result is not None: if verbose: print("\n-----------------------------------------------------------") print("Calling %s with arguments %s " % (path, [item for item in reversed(engine.invocation_args)])) diff --git a/neo/Prompt/Commands/Config.py b/neo/Prompt/Commands/Config.py index 4832de89d..7e88c3fc2 100644 --- a/neo/Prompt/Commands/Config.py +++ b/neo/Prompt/Commands/Config.py @@ -1,12 +1,13 @@ -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt from neo.logging import log_manager from neo.Prompt.CommandBase import CommandBase, CommandDesc, ParameterDesc from neo.Prompt.Utils import get_arg from neo.Settings import settings -from neo.Network.NodeLeader import NodeLeader from neo.Prompt.PromptPrinter import prompt_print as print from distutils 
import util +from neo.Network.nodemanager import NodeManager import logging +from neo.Network.common import wait_for class CommandConfig(CommandBase): @@ -17,8 +18,8 @@ def __init__(self): self.register_sub_command(CommandConfigSCEvents()) self.register_sub_command(CommandConfigDebugNotify()) self.register_sub_command(CommandConfigVMLog()) - self.register_sub_command(CommandConfigNodeRequests()) self.register_sub_command(CommandConfigMaxpeers()) + self.register_sub_command(CommandConfigMinpeers()) self.register_sub_command(CommandConfigNEP8()) def command_desc(self): @@ -138,35 +139,6 @@ def command_desc(self): return CommandDesc('vm-log', 'toggle VM instruction execution logging to file', [p1]) -class CommandConfigNodeRequests(CommandBase): - def __init__(self): - super().__init__() - - def execute(self, arguments): - if len(arguments) in [1, 2]: - if len(arguments) == 2: - try: - return NodeLeader.Instance().setBlockReqSizeAndMax(int(arguments[0]), int(arguments[1])) - except ValueError: - print("Invalid values. Please specify a block request part and max size for each node, like 30 and 1000") - return False - elif len(arguments) == 1: - return NodeLeader.Instance().setBlockReqSizeByName(arguments[0]) - else: - print("Please specify the required parameter") - return False - - def command_desc(self): - p1 = ParameterDesc('block-size', 'preset of "slow"/"normal"/"fast", or a specific block request size (max. 500) e.g. 250 ') - p2 = ParameterDesc('queue-size', 'maximum number of outstanding block requests') - return CommandDesc('node-requests', 'configure block request settings', [p1, p2]) - - def handle_help(self, arguments): - super().handle_help(arguments) - print(f"\nCurrent settings {self.command_desc().params[0].name}:" - f" {NodeLeader.Instance().BREQPART} {self.command_desc().params[1].name}: {NodeLeader.Instance().BREQMAX}") - - class CommandConfigNEP8(CommandBase): def __init__(self): super().__init__() @@ -202,24 +174,36 @@ def __init__(self): def execute(self, arguments): c1 = get_arg(arguments) if c1 is not None: + try: - current_max = settings.CONNECTED_PEER_MAX - settings.set_max_peers(c1) c1 = int(c1) - p_len = len(NodeLeader.Instance().Peers) - if c1 < current_max and c1 < p_len: - to_remove = p_len - c1 - peers = NodeLeader.Instance().Peers - for i in range(to_remove): - peer = peers[-1] # disconnect last peer added first - peer.Disconnect("Max connected peers reached", isDead=False) - peers.pop() - - print(f"Maxpeers set to {c1}") - return c1 + except ValueError: + print("Invalid argument") + return + + if c1 > 10: + print("Max peers is limited to 10") + return + + try: + settings.set_max_peers(c1) except ValueError: print("Please supply a positive integer for maxpeers") return + + nodemgr = NodeManager() + nodemgr.max_clients = c1 + + current_max = settings.CONNECTED_PEER_MAX + connected_count = len(nodemgr.nodes) + if c1 < current_max and c1 < connected_count: + to_remove = connected_count - c1 + for _ in range(to_remove): + last_connected_node = nodemgr.nodes[-1] + wait_for(last_connected_node.disconnect()) # need to avoid it being labelled as dead/bad + + print(f"Maxpeers set to {c1}") + return c1 else: print(f"Maintaining maxpeers at {settings.CONNECTED_PEER_MAX}") return @@ -229,6 +213,33 @@ def command_desc(self): return CommandDesc('maxpeers', 'configure number of max peers', [p1]) +class CommandConfigMinpeers(CommandBase): + def __init__(self): + super().__init__() + + def execute(self, arguments): + c1 = get_arg(arguments) + if c1 is not None: + try: + c1 = 
int(c1) + if c1 > settings.CONNECTED_PEER_MAX: + print('minpeers setting cannot be bigger than maxpeers setting') + return + settings.set_min_peers(c1) + except ValueError: + print("Please supply a positive integer for minpeers") + return + print(f"Minpeers set to {c1}") + return c1 + else: + print(f"Maintaining minpeers at {settings.CONNECTED_PEER_MIN}") + return + + def command_desc(self): + p1 = ParameterDesc('number', 'minimum number of nodes to connect to') + return CommandDesc('minpeers', 'configure number of min peers', [p1]) + + def start_output_config(): # temporarily mute stdout while we try to reconfigure our settings # components like `network` set at DEBUG level will spam through the console diff --git a/neo/Prompt/Commands/Invoke.py b/neo/Prompt/Commands/Invoke.py index dfb1081ab..ea38b80a1 100644 --- a/neo/Prompt/Commands/Invoke.py +++ b/neo/Prompt/Commands/Invoke.py @@ -3,7 +3,6 @@ from neo.Blockchain import GetBlockchain from neo.VM.ScriptBuilder import ScriptBuilder from neo.VM.InteropService import InteropInterface -from neo.Network.NodeLeader import NodeLeader from neo.Prompt import Utils as PromptUtils from neo.Implementations.Blockchains.LevelDB.DBCollection import DBCollection from neo.Implementations.Blockchains.LevelDB.DBPrefix import DBPrefix @@ -31,16 +30,18 @@ from neo.Settings import settings from neo.Core.Blockchain import Blockchain from neo.EventHub import events -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt from copy import deepcopy from neo.logging import log_manager from neo.Prompt.PromptPrinter import prompt_print as print +from neo.Network.nodemanager import NodeManager logger = log_manager.getLogger() from neo.Core.Cryptography.ECCurve import ECDSA from neo.Core.UInt160 import UInt160 from neo.VM.OpCode import PACK +from neo.VM.Debugger import Debugger DEFAULT_MIN_FEE = Fixed8.FromDecimal(.0001) @@ -77,9 +78,8 @@ def InvokeContract(wallet, tx, fee=Fixed8.Zero(), from_addr=None, owners=None): relayed = False - # print("SENDING TX: %s " % json.dumps(wallet_tx.ToJson(), indent=4)) - - relayed = NodeLeader.Instance().Relay(wallet_tx) + nodemgr = NodeManager() + relayed = nodemgr.relay(wallet_tx) if relayed: print("Relayed Tx: %s " % wallet_tx.Hash.ToString()) @@ -140,7 +140,8 @@ def InvokeWithTokenVerificationScript(wallet, tx, token, fee=Fixed8.Zero(), invo wallet_tx.scripts = context.GetScripts() - relayed = NodeLeader.Instance().Relay(wallet_tx) + nodemgr = NodeManager() + relayed = nodemgr.relay(wallet_tx) if relayed: print("Relayed Tx: %s " % wallet_tx.Hash.ToString()) @@ -409,7 +410,7 @@ def test_invoke(script, wallet, outputs, withdrawal_tx=None, def test_deploy_and_invoke(deploy_script, invoke_args, wallet, from_addr=None, min_fee=DEFAULT_MIN_FEE, invocation_test_mode=True, - debug_map=None, invoke_attrs=None, owners=None): + debug_map=None, invoke_attrs=None, owners=None, enable_debugger=False): bc = GetBlockchain() accounts = DBCollection(bc._db, DBPrefix.ST_Account, AccountState) @@ -464,8 +465,11 @@ def test_deploy_and_invoke(deploy_script, invoke_args, wallet, # first we will execute the test deploy # then right after, we execute the test invoke - - d_success = engine.Execute() + if enable_debugger: + debugger = Debugger(engine) + d_success = debugger.Execute() + else: + d_success = engine.Execute() if d_success: @@ -593,7 +597,11 @@ def test_deploy_and_invoke(deploy_script, invoke_args, wallet, engine.LoadScript(itx.Script) engine.LoadDebugInfoForScriptHash(debug_map, shash.Data) - i_success = 
engine.Execute() + if enable_debugger: + debugger = Debugger(engine) + i_success = debugger.Execute() + else: + i_success = engine.Execute() service.ExecutionCompleted(engine, i_success) to_dispatch = to_dispatch + service.events_to_dispatch diff --git a/neo/Prompt/Commands/LoadSmartContract.py b/neo/Prompt/Commands/LoadSmartContract.py index 5c6d2907a..96f976a48 100644 --- a/neo/Prompt/Commands/LoadSmartContract.py +++ b/neo/Prompt/Commands/LoadSmartContract.py @@ -3,7 +3,7 @@ from neo.Core.FunctionCode import FunctionCode from neo.Core.State.ContractState import ContractPropertyState from neo.SmartContract.ContractParameterType import ContractParameterType -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt import json from neo.VM.ScriptBuilder import ScriptBuilder from neo.Core.Blockchain import Blockchain diff --git a/neo/Prompt/Commands/SC.py b/neo/Prompt/Commands/SC.py index c15cd6298..8076e9a5a 100644 --- a/neo/Prompt/Commands/SC.py +++ b/neo/Prompt/Commands/SC.py @@ -8,7 +8,7 @@ from neo.Core.UInt160 import UInt160 from neo.SmartContract.ContractParameter import ContractParameter from neo.SmartContract.ContractParameterType import ContractParameterType -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt from neo.Core.Fixed8 import Fixed8 from neo.Implementations.Blockchains.LevelDB.DebugStorage import DebugStorage from distutils import util @@ -186,11 +186,12 @@ def execute(self, arguments): return False tx, fee, results, num_ops, engine_success = TestInvokeContract(wallet, arguments, from_addr=from_addr, invoke_attrs=invoke_attrs, owners=owners) - if tx and results: + if tx is not None and results is not None: if return_type is not None: try: - parameterized_results = [ContractParameter.AsParameterType(ContractParameterType.FromString(return_type), item).ToJson() for item in results] + parameterized_results = [ContractParameter.AsParameterType(ContractParameterType.FromString(return_type), item).ToJson() for item in + results] except ValueError: logger.debug("invalid return type") return False @@ -235,7 +236,7 @@ def command_desc(self): p6 = ParameterDesc('--from-addr', 'source address to take fee funds from (if not specified, take first address in wallet)', optional=True) p7 = ParameterDesc('--fee', 'Attach GAS amount to give your transaction priority (> 0.001) e.g. --fee=0.01', optional=True) p8 = ParameterDesc('--owners', 'list of NEO addresses indicating the transaction owners e.g. --owners=[address1,address2]', optional=True) - p9 = ParameterDesc('--return-type', 'override the return parameter type e.g. --return-type=02', optional=True) + p9 = ParameterDesc('--return-type', 'override the return parameter type e.g. 
--return-type=02', optional=True) p10 = ParameterDesc('--tx-attr', 'a list of transaction attributes to attach to the transaction\n\n' f"{' ':>17} See: http://docs.neo.org/en-us/network/network-protocol.html section 4 for a description of possible attributes\n\n" @@ -303,7 +304,7 @@ def execute(self, arguments): return False tx, fee, results, num_ops, engine_success = test_invoke(contract_script, wallet, [], from_addr=from_addr) - if tx and results: + if tx is not None and results is not None: print( "\n-------------------------------------------------------------------------------------------------------------------------------------") print("Test deploy invoke successful") diff --git a/neo/Prompt/Commands/Send.py b/neo/Prompt/Commands/Send.py index a351cb37d..870c7400f 100755 --- a/neo/Prompt/Commands/Send.py +++ b/neo/Prompt/Commands/Send.py @@ -1,7 +1,6 @@ from neo.Core.TX.Transaction import TransactionOutput, ContractTransaction, TXFeeError from neo.Core.TX.TransactionAttribute import TransactionAttribute, TransactionAttributeUsage from neo.SmartContract.ContractParameterContext import ContractParametersContext -from neo.Network.NodeLeader import NodeLeader from neo.Prompt.Utils import get_arg, get_from_addr, get_asset_id, lookup_addr_str, get_tx_attr_from_args, \ get_owners_from_params, get_fee, get_change_addr, get_asset_amount from neo.Prompt.Commands.Tokens import do_token_transfer, amount_from_string @@ -9,13 +8,14 @@ from neo.Wallets.NEP5Token import NEP5Token from neo.Core.Fixed8 import Fixed8 import json -from prompt_toolkit import prompt import traceback from neo.Prompt.PromptData import PromptData from neo.Prompt.CommandBase import CommandBase, CommandDesc, ParameterDesc from logzero import logger from neo.Prompt.PromptPrinter import prompt_print as print from neo.Core.Blockchain import Blockchain +from neo.Network.nodemanager import NodeManager +from neo.Network.common import blocking_prompt as prompt class CommandWalletSend(CommandBase): @@ -320,8 +320,8 @@ def process_transaction(wallet, contract_tx, scripthash_from=None, scripthash_ch if context.Completed: tx.scripts = context.GetScripts() - relayed = NodeLeader.Instance().Relay(tx) - + nodemgr = NodeManager() + relayed = nodemgr.relay(tx) if relayed: wallet.SaveTransaction(tx) @@ -364,7 +364,8 @@ def parse_and_sign(wallet, jsn): print("will send tx: %s " % json.dumps(tx.ToJson(), indent=4)) - relayed = NodeLeader.Instance().Relay(tx) + nodemgr = NodeManager() + relayed = nodemgr.relay(tx) if relayed: print("Relayed Tx: %s " % tx.Hash.ToString()) diff --git a/neo/Prompt/Commands/Show.py b/neo/Prompt/Commands/Show.py index cdd89b8da..ecfe0750f 100644 --- a/neo/Prompt/Commands/Show.py +++ b/neo/Prompt/Commands/Show.py @@ -8,10 +8,11 @@ from neo.Core.UInt256 import UInt256 from neo.Core.UInt160 import UInt160 from neo.IO.MemoryStream import StreamManager -from neo.Network.NodeLeader import NodeLeader from neo.Implementations.Notifications.LevelDB.NotificationDB import NotificationDB from neo.logging import log_manager from neo.Prompt.PromptPrinter import prompt_print as print +from neo.Network.nodemanager import NodeManager +from neo.Network.syncmanager import SyncManager import json logger = log_manager.getLogger() @@ -161,18 +162,63 @@ def __init__(self): super().__init__() def execute(self, arguments=None): - if len(NodeLeader.Instance().Peers) > 0: - out = "Total Connected: %s\n" % len(NodeLeader.Instance().Peers) - for i, peer in enumerate(NodeLeader.Instance().Peers): - out += f"Peer {i} {peer.Name():>12} - 
{peer.address:>21} - IO {peer.IOStats()}\n" - print(out) - return out + show_verbose = get_arg(arguments) == 'verbose' + show_queued = get_arg(arguments) == 'queued' + show_known = get_arg(arguments) == 'known' + show_bad = get_arg(arguments) == 'bad' + + nodemgr = NodeManager() + len_nodes = len(nodemgr.nodes) + out = "" + if len_nodes > 0: + out = f"Connected: {len_nodes} of max {nodemgr.max_clients}\n" + for i, node in enumerate(nodemgr.nodes): + out += f"Peer {i} {node.version.user_agent:>12} {node.address:>21} height: {node.best_height:>8}\n" else: - print("Not connected yet\n") - return + print("No nodes connected yet\n") + + if show_verbose: + out += f"\n" + out += f"Addresses in queue: {len(nodemgr.queued_addresses)}\n" + out += f"Known addresses: {len(nodemgr.known_addresses)}\n" + out += f"Bad addresses: {len(nodemgr.bad_addresses)}\n" + + if show_queued: + out += f"\n" + if len(nodemgr.queued_addresses) == 0: + out += "No queued addresses" + else: + out += f"Queued addresses:\n" + for addr in nodemgr.queued_addresses: + out += f"{addr}\n" + + if show_known: + out += f"\n" + if len(nodemgr.known_addresses) == 0: + out += "No known addresses other than connect peers" + else: + out += f"Known addresses:\n" + for addr in nodemgr.known_addresses: + out += f"{addr}\n" + + if show_bad: + out += f"\n" + if len(nodemgr.bad_addresses) == 0: + out += "No bad addresses" + else: + out += f"Bad addresses:\n" + for addr in nodemgr.bad_addresses: + out += f"{addr}\n" + print(out) + return out def command_desc(self): - return CommandDesc('nodes', 'show connected peers') + p1 = ParameterDesc('verbose', 'also show the number of queued, known, and bad addresses', optional=True) + p2 = ParameterDesc('queued', 'also list the queued addresses', optional=True) + p3 = ParameterDesc('known', 'also list the known addresses', optional=True) + p4 = ParameterDesc('bad', 'also list the bad addresses', optional=True) + params = [p1, p2, p3, p4] + return CommandDesc('nodes', 'show connected peers and their blockheight', params=params) class CommandShowState(CommandBase): @@ -196,8 +242,10 @@ def execute(self, arguments=None): bpm = diff / mins tps = Blockchain.Default().TXProcessed / secs + syncmngr = SyncManager() + out = "Progress: %s / %s\n" % (height, headers) - out += "Block-cache length %s\n" % Blockchain.Default().BlockCacheCount + out += "Block-cache length %s\n" % len(syncmngr.block_cache) out += "Blocks since program start %s\n" % diff out += "Time elapsed %s mins\n" % mins out += "Blocks per min %s \n" % bpm diff --git a/neo/Prompt/Commands/Tokens.py b/neo/Prompt/Commands/Tokens.py index 415f21781..36d07593e 100644 --- a/neo/Prompt/Commands/Tokens.py +++ b/neo/Prompt/Commands/Tokens.py @@ -1,8 +1,7 @@ from neo.Prompt.Commands.Invoke import InvokeContract, InvokeWithTokenVerificationScript -from neo.Wallets.NEP5Token import NEP5Token from neo.Core.Fixed8 import Fixed8 from neo.Core.UInt160 import UInt160 -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt from decimal import Decimal from neo.Core.TX.TransactionAttribute import TransactionAttribute import binascii @@ -197,7 +196,7 @@ def execute(self, arguments): logger.error(traceback.format_exc()) return False - if tx and results: + if tx is not None and results is not None: vm_result = results[0].GetBigInteger() if vm_result == 1: print("\n-----------------------------------------------------------") @@ -318,7 +317,7 @@ def execute(self, arguments): tx, fee, results = token.Approve(wallet, from_addr, to_addr, 
decimal_amount) - if tx and results: + if tx is not None and results is not None: if results[0].GetBigInteger() == 1: print("\n-----------------------------------------------------------") print(f"Approve allowance of {amount} {token.symbol} from {from_addr} to {to_addr}") @@ -453,7 +452,7 @@ def execute(self, arguments): logger.debug("invalid fee") return False - return token_mint(token, wallet, to_addr, asset_attachments=asset_attachments, fee=fee, invoke_attrs=invoke_attrs) + return token_mint(token, wallet, to_addr, asset_attachments=asset_attachments, fee=fee, invoke_attrs=invoke_attrs) def command_desc(self): p1 = ParameterDesc('symbol', 'token symbol or script hash') @@ -508,8 +507,8 @@ def execute(self, arguments): tx, fee, results = token.CrowdsaleRegister(wallet, addr_list) - if tx and results: - if results[0].GetBigInteger() > 0: + if tx is not None and results is not None: + if len(results) > 0 and results[0].GetBigInteger() > 0: print("\n-----------------------------------------------------------") print("[%s] Will register addresses for crowdsale: %s " % (token.symbol, register_addr)) print("Invocation Fee: %s " % (fee.value / Fixed8.D)) @@ -676,7 +675,7 @@ def token_get_allowance(wallet, token_str, from_addr, to_addr, verbose=False): tx, fee, results = token.Allowance(wallet, from_addr, to_addr) - if tx and results: + if tx is not None and results is not None: allowance = results[0].GetBigInteger() if verbose: print("%s allowance for %s from %s : %s " % (token.symbol, from_addr, to_addr, allowance)) @@ -696,8 +695,8 @@ def token_mint(token, wallet, to_addr, asset_attachments=[], fee=Fixed8.Zero(), tx, fee, results = token.Mint(wallet, to_addr, asset_attachments, invoke_attrs=invoke_attrs) - if tx and results: - if results[0] is not None: + if tx is not None and results is not None: + if len(results) > 0 and results[0] is not None: print("\n-----------------------------------------------------------") print(f"[{token.symbol}] Will mint tokens to address: {to_addr}") print(f"Invocation Fee: {fee.value / Fixed8.D}") diff --git a/neo/Prompt/Commands/Wallet.py b/neo/Prompt/Commands/Wallet.py index 3cd648439..2c5f89955 100644 --- a/neo/Prompt/Commands/Wallet.py +++ b/neo/Prompt/Commands/Wallet.py @@ -3,15 +3,15 @@ from neo.Core.TX.Transaction import TransactionOutput from neo.Core.TX.TransactionAttribute import TransactionAttribute, TransactionAttributeUsage from neo.SmartContract.ContractParameterContext import ContractParametersContext -from neo.Network.NodeLeader import NodeLeader from neo.Prompt import Utils as PromptUtils from neo.Wallets.utils import to_aes_key from neo.Implementations.Wallets.peewee.UserWallet import UserWallet from neo.Core.Fixed8 import Fixed8 from neo.Core.UInt160 import UInt160 -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt import json import os +import asyncio from neo.Prompt.CommandBase import CommandBase, CommandDesc, ParameterDesc from neo.Prompt.PromptData import PromptData from neo.Prompt.Commands.Send import CommandWalletSend, CommandWalletSendMany, CommandWalletSign @@ -22,6 +22,7 @@ from neo.logging import log_manager from neo.Core.Utils import isValidPublicAddress from neo.Prompt.PromptPrinter import prompt_print as print +from neo.Network.nodemanager import NodeManager logger = log_manager.getLogger() @@ -60,7 +61,7 @@ def execute(self, arguments): print("Please open a wallet") return - if not item or item == 'verbose': + if not item: wallet.pretty_print(item) return wallet @@ -124,7 +125,7 @@ def 
execute(self, arguments): return if PromptData.Wallet: - PromptData.Prompt.start_wallet_loop() + asyncio.create_task(PromptData.Wallet.sync_wallet(start_block=PromptData.Wallet._current_height)) return PromptData.Wallet def command_desc(self): @@ -160,9 +161,8 @@ def execute(self, arguments): try: PromptData.Wallet = UserWallet.Open(path, password_key) - - PromptData.Prompt.start_wallet_loop() print("Opened wallet at %s" % path) + asyncio.create_task(PromptData.Wallet.sync_wallet(start_block=PromptData.Wallet._current_height)) return PromptData.Wallet except Exception as e: print("Could not open wallet: %s" % e) @@ -190,8 +190,9 @@ def __init__(self): super().__init__() def execute(self, arguments=None): - print("Wallet %s " % json.dumps(PromptData.Wallet.ToJson(verbose=True), indent=4)) - return True + wallet = PromptData.Wallet + wallet.pretty_print(verbose=True) + return wallet def command_desc(self): return CommandDesc('verbose', 'show additional wallet details') @@ -227,16 +228,12 @@ def __init__(self): super().__init__() def execute(self, arguments): - PromptData.Prompt.stop_wallet_loop() - start_block = PromptUtils.get_arg(arguments, 0, convert_to_int=True) if not start_block or start_block < 0: start_block = 0 print(f"Restarting at block {start_block}") - - PromptData.Wallet.Rebuild(start_block) - - PromptData.Prompt.start_wallet_loop() + task = asyncio.create_task(PromptData.Wallet.sync_wallet(start_block, rebuild=True)) + return task def command_desc(self): p1 = ParameterDesc('start_block', 'block number to start the resync at', optional=True) @@ -284,6 +281,22 @@ def command_desc(self): ######################################################################### ######################################################################### +async def sync_wallet(start_block, rebuild=False): + Blockchain.Default().PersistCompleted.on_change -= PromptData.Wallet.ProcessNewBlock + + if rebuild: + PromptData.Wallet.Rebuild(start_block) + while True: + # trying with 100, might need to lower if processing takes too long + PromptData.Wallet.ProcessBlocks(block_limit=100) + + if PromptData.Wallet.IsSynced: + break + # give some time to other tasks + await asyncio.sleep(0.05) + + Blockchain.Default().PersistCompleted.on_change += PromptData.Wallet.ProcessNewBlock + def ClaimGas(wallet, from_addr_str=None, to_addr_str=None): """ @@ -368,7 +381,8 @@ def ClaimGas(wallet, from_addr_str=None, to_addr_str=None): print("claim tx: %s " % json.dumps(claim_tx.ToJson(), indent=4)) - relayed = NodeLeader.Instance().Relay(claim_tx) + nodemgr = NodeManager() + relayed = nodemgr.relay(claim_tx) if relayed: print("Relayed Tx: %s " % claim_tx.Hash.ToString()) diff --git a/neo/Prompt/Commands/WalletAddress.py b/neo/Prompt/Commands/WalletAddress.py index 459e7a0f5..c0a6326f5 100644 --- a/neo/Prompt/Commands/WalletAddress.py +++ b/neo/Prompt/Commands/WalletAddress.py @@ -6,14 +6,12 @@ from neo.Core.Utils import isValidPublicAddress from neo.Core.Fixed8 import Fixed8 from neo.SmartContract.ContractParameterContext import ContractParametersContext -from neo.Network.NodeLeader import NodeLeader -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt from neo.Core.Blockchain import Blockchain from neo.Core.TX.Transaction import ContractTransaction from neo.Core.TX.Transaction import TransactionOutput from neo.Prompt.PromptPrinter import prompt_print as print - -import sys +from neo.Network.nodemanager import NodeManager class CommandWalletAddress(CommandBase): @@ -277,7 +275,9 @@ 
def SplitUnspentCoin(wallet, asset_id, from_addr, index, divisions, fee=Fixed8.Z if ctx.Completed: contract_tx.scripts = ctx.GetScripts() - relayed = NodeLeader.Instance().Relay(contract_tx) + nodemgr = NodeManager() + # this blocks, consider moving this wallet function to async instead + relayed = nodemgr.relay(contract_tx) if relayed: wallet.SaveTransaction(contract_tx) diff --git a/neo/Prompt/Commands/WalletExport.py b/neo/Prompt/Commands/WalletExport.py index c0f4e6cf0..0bc6e2844 100644 --- a/neo/Prompt/Commands/WalletExport.py +++ b/neo/Prompt/Commands/WalletExport.py @@ -1,7 +1,7 @@ from neo.Prompt.CommandBase import CommandBase, CommandDesc, ParameterDesc from neo.Prompt import Utils as PromptUtils from neo.Prompt.PromptData import PromptData -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt from neo.Prompt.PromptPrinter import prompt_print as print diff --git a/neo/Prompt/Commands/WalletImport.py b/neo/Prompt/Commands/WalletImport.py index a603e9097..265eae7b8 100644 --- a/neo/Prompt/Commands/WalletImport.py +++ b/neo/Prompt/Commands/WalletImport.py @@ -5,7 +5,7 @@ from neo.Prompt.Commands.LoadSmartContract import ImportContractAddr from neo.Prompt import Utils as PromptUtils from neo.Core.KeyPair import KeyPair -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt from neo.Core.Utils import isValidPublicAddress from neo.Core.UInt160 import UInt160 from neo.Core.Cryptography.Crypto import Crypto diff --git a/neo/Prompt/Commands/tests/test_address_commands.py b/neo/Prompt/Commands/tests/test_address_commands.py index a2c1b2df0..e73536d8f 100644 --- a/neo/Prompt/Commands/tests/test_address_commands.py +++ b/neo/Prompt/Commands/tests/test_address_commands.py @@ -7,7 +7,8 @@ from neo.Core.Fixed8 import Fixed8 from mock import patch from io import StringIO -import os +from neo.Network.nodemanager import NodeManager +from neo.Network.node import NeoNode class UserWalletTestCase(UserWalletTestCaseBase): @@ -105,35 +106,38 @@ def test_wallet_alias(self): self.assertIn('mine', [n.Title for n in PromptData.Wallet.NamedAddr]) def test_6_split_unspent(self): - # os.environ["NEOPYTHON_UNITTEST"] = "1" wallet = self.GetWallet1(recreate=True) addr = wallet.ToScriptHash('AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3') - # # bad inputs - # tx = SplitUnspentCoin(None, self.NEO, addr, 0, 2) - # self.assertEqual(tx, None) - # - # tx = SplitUnspentCoin(wallet, self.NEO, addr, 3, 2) - # self.assertEqual(tx, None) - # - # tx = SplitUnspentCoin(wallet, 'bla', addr, 0, 2) - # self.assertEqual(tx, None) - - # should be ok - with patch('neo.Prompt.Commands.WalletAddress.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - tx = SplitUnspentCoin(wallet, self.NEO, addr, 0, 2) - self.assertIsNotNone(tx) - - # # rebuild wallet and try with non-even amount of neo, should be split into integer values of NEO - # wallet = self.GetWallet1(True) - # tx = SplitUnspentCoin(wallet, self.NEO, addr, 0, 3) - # self.assertIsNotNone(tx) - # self.assertEqual([Fixed8.FromDecimal(17), Fixed8.FromDecimal(17), Fixed8.FromDecimal(16)], [item.Value for item in tx.outputs]) - # - # # try with gas - # wallet = self.GetWallet1(True) - # tx = SplitUnspentCoin(wallet, self.GAS, addr, 0, 3) - # self.assertIsNotNone(tx) + nodemgr = NodeManager() + nodemgr.nodes = [NeoNode(object, object)] + + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + # bad inputs + tx = SplitUnspentCoin(None, self.NEO, addr, 0, 2) + 
self.assertEqual(tx, None) + + tx = SplitUnspentCoin(wallet, self.NEO, addr, 3, 2) + self.assertEqual(tx, None) + + tx = SplitUnspentCoin(wallet, 'bla', addr, 0, 2) + self.assertEqual(tx, None) + + # should be ok + with patch('neo.Prompt.Commands.WalletAddress.prompt', return_value=self.wallet_1_pass()): + tx = SplitUnspentCoin(wallet, self.NEO, addr, 0, 2) + self.assertIsNotNone(tx) + + # rebuild wallet and try with non-even amount of neo, should be split into integer values of NEO + wallet = self.GetWallet1(True) + tx = SplitUnspentCoin(wallet, self.NEO, addr, 0, 3) + self.assertIsNotNone(tx) + self.assertEqual([Fixed8.FromDecimal(17), Fixed8.FromDecimal(17), Fixed8.FromDecimal(16)], [item.Value for item in tx.outputs]) + + # try with gas + wallet = self.GetWallet1(True) + tx = SplitUnspentCoin(wallet, self.GAS, addr, 0, 3) + self.assertIsNotNone(tx) def test_7_create_address(self): # no wallet @@ -262,27 +266,36 @@ def test_wallet_split(self): self.assertIsNone(res) self.assertIn("Fee could not be subtracted from outputs", mock_print.getvalue()) - # test wallet split with error during tx relay + # # test wallet split with error during tx relay + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] + with patch('neo.Prompt.Commands.WalletAddress.prompt', side_effect=[self.wallet_1_pass()]): - with patch('neo.Network.NodeLeader.NodeLeader.Relay', side_effect=[None]): + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(False)): with patch('sys.stdout', new=StringIO()) as mock_print: args = ['address', 'split', self.wallet_1_addr, 'neo', '0', '2'] res = CommandWallet().execute(args) self.assertIsNone(res) self.assertIn("Could not relay tx", mock_print.getvalue()) + # we have to clear the mempool because the previous test already put a TX with the same hash in the mempool and so it will not try to relay again + nodemgr.mempool.reset() + # test wallet split neo successful with patch('neo.Prompt.Commands.WalletAddress.prompt', side_effect=[self.wallet_1_pass()]): - args = ['address', 'split', self.wallet_1_addr, 'neo', '0', '2'] - tx = CommandWallet().execute(args) - self.assertTrue(tx) - self.assertIsInstance(tx, ContractTransaction) - self.assertEqual([Fixed8.FromDecimal(25), Fixed8.FromDecimal(25)], [item.Value for item in tx.outputs]) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + args = ['address', 'split', self.wallet_1_addr, 'neo', '0', '2'] + tx = CommandWallet().execute(args) + self.assertIsInstance(tx, ContractTransaction) + self.assertEqual([Fixed8.FromDecimal(25), Fixed8.FromDecimal(25)], [item.Value for item in tx.outputs]) # test wallet split gas successful with patch('neo.Prompt.Commands.WalletAddress.prompt', side_effect=[self.wallet_1_pass()]): - args = ['address', 'split', self.wallet_1_addr, 'gas', '0', '3'] - tx = CommandWallet().execute(args) - self.assertTrue(tx) - self.assertIsInstance(tx, ContractTransaction) - self.assertEqual(len(tx.outputs), 3) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + args = ['address', 'split', self.wallet_1_addr, 'gas', '0', '3'] + tx = CommandWallet().execute(args) + self.assertIsInstance(tx, ContractTransaction) + self.assertEqual(len(tx.outputs), 3) + + nodemgr.reset_for_test() diff --git a/neo/Prompt/Commands/tests/test_claim_command.py b/neo/Prompt/Commands/tests/test_claim_command.py index a9b5eb38c..97164565f 100644 --- a/neo/Prompt/Commands/tests/test_claim_command.py +++
b/neo/Prompt/Commands/tests/test_claim_command.py @@ -6,7 +6,8 @@ from neo.Prompt.Commands.Wallet import ClaimGas from neo.Core.Fixed8 import Fixed8 from neo.Core.TX.ClaimTransaction import ClaimTransaction -from neo.Prompt.PromptPrinter import pp +from neo.Network.node import NeoNode +from neo.Network.nodemanager import NodeManager import shutil from mock import patch from io import StringIO @@ -117,19 +118,29 @@ def test_4_keyboard_interupt(self): wallet = self.GetWallet1() claim_tx, relayed = ClaimGas(wallet) - self.assertEqual(claim_tx, None) - self.assertFalse(relayed) - self.assertIn("Claim transaction cancelled", mock_print.getvalue()) + self.assertEqual(claim_tx, None) + self.assertFalse(relayed) + self.assertIn("Claim transaction cancelled", mock_print.getvalue()) def test_5_wallet_claim_ok(self): - with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - wallet = self.GetWallet1() - claim_tx, relayed = ClaimGas(wallet) - self.assertIsInstance(claim_tx, ClaimTransaction) - self.assertTrue(relayed) + wallet = self.GetWallet1() + nodemgr = NodeManager() + nodemgr.nodes = [NeoNode(object, object)] + + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('neo.Prompt.Commands.Wallet.prompt', return_value=self.wallet_1_pass()): + claim_tx, relayed = ClaimGas(wallet) + self.assertIsInstance(claim_tx, ClaimTransaction) + self.assertTrue(relayed) def test_6_no_wallet(self): + with patch('neo.Prompt.Commands.Wallet.prompt', return_value=self.wallet_1_pass()): + claim_tx, relayed = ClaimGas(None) + self.assertEqual(claim_tx, None) + self.assertFalse(relayed) + + def test_7_no_wallet(self): claim_tx, relayed = ClaimGas(None) self.assertEqual(claim_tx, None) self.assertFalse(relayed) diff --git a/neo/Prompt/Commands/tests/test_config_commands.py b/neo/Prompt/Commands/tests/test_config_commands.py index c38085404..c1dbde1f0 100644 --- a/neo/Prompt/Commands/tests/test_config_commands.py +++ b/neo/Prompt/Commands/tests/test_config_commands.py @@ -2,8 +2,6 @@ from neo.Settings import settings from neo.Utils.BlockchainFixtureTestCase import BlockchainFixtureTestCase from neo.Prompt.Commands.Config import CommandConfig -from neo.Network.NodeLeader import NodeLeader, NeoNode -from neo.Network.address import Address from mock import patch from io import StringIO from neo.Prompt.PromptPrinter import pp @@ -29,7 +27,8 @@ def test_config_output(self): from neo.Implementations.Wallets.peewee.UserWallet import UserWallet args = ['output'] - with patch('neo.Prompt.Commands.Config.prompt', side_effect=[1, 1, 1, "a", "\n", "\n"]): # tests changing the level and keeping the current level. Entering "a" has no effect. + with patch('neo.Prompt.Commands.Config.prompt', + side_effect=[1, 1, 1, "a", "\n", "\n"]): # tests changing the level and keeping the current level. Entering "a" has no effect. 
res = CommandConfig().execute(args) self.assertTrue(res) self.assertEqual(res['generic'], "DEBUG") @@ -37,7 +36,6 @@ def test_config_output(self): self.assertEqual(res['db'], "DEBUG") self.assertEqual(res['peewee'], "ERROR") self.assertEqual(res['network'], "INFO") - self.assertEqual(res['network.verbose'], "INFO") # test with keyboard interrupt with patch('sys.stdout', new=StringIO()) as mock_print: @@ -115,59 +113,6 @@ def test_config_vm_log(self): res = CommandConfig().execute(args) self.assertFalse(res) - def test_config_node_requests(self): - # test no input - args = ['node-requests'] - res = CommandConfig().execute(args) - self.assertFalse(res) - - # test updating block request size - # first make sure we have a predictable state - NodeLeader.Instance().Reset() - leader = NodeLeader.Instance() - leader.ADDRS = ["127.0.0.1:20333", "127.0.0.2:20334"] - leader.DEAD_ADDRS = ["127.0.0.1:20335"] - - # test slow setting - args = ['node-requests', 'slow'] - res = CommandConfig().execute(args) - self.assertTrue(res) - - # test normal setting - args = ['node-requests', 'normal'] - res = CommandConfig().execute(args) - self.assertTrue(res) - - # test fast setting - args = ['node-requests', 'fast'] - res = CommandConfig().execute(args) - self.assertTrue(res) - - # test bad setting - args = ['node-requests', 'blah'] - res = CommandConfig().execute(args) - self.assertFalse(res) - - # test custom setting - args = ['node-requests', '20', '6000'] - res = CommandConfig().execute(args) - self.assertTrue(res) - - # test bad custom input - args = ['node-requests', '20', 'blah'] - res = CommandConfig().execute(args) - self.assertFalse(res) - - # test bad custom setting: breqmax should be greater than breqpart - args = ['node-requests', '20', '10'] - res = CommandConfig().execute(args) - self.assertFalse(res) - - # test another bad custom setting: breqpart should not exceed 500 - args = ['node-requests', '600', '5000'] - res = CommandConfig().execute(args) - self.assertFalse(res) - def test_config_maxpeers(self): # test no input and verify output confirming current maxpeers with patch('sys.stdout', new=StringIO()) as mock_print: @@ -189,7 +134,7 @@ def test_config_maxpeers(self): args = ['maxpeers', "blah"] res = CommandConfig().execute(args) self.assertFalse(res) - self.assertIn("Please supply a positive integer for maxpeers", mock_print.getvalue()) + self.assertIn("Invalid argument", mock_print.getvalue()) # test negative number with patch('sys.stdout', new=StringIO()) as mock_print: @@ -198,43 +143,42 @@ def test_config_maxpeers(self): self.assertFalse(res) self.assertIn("Please supply a positive integer for maxpeers", mock_print.getvalue()) - # test if the new maxpeers < settings.CONNECTED_PEER_MAX - # first make sure we have a predictable state - NodeLeader.Instance().Reset() - leader = NodeLeader.Instance() - addr1 = Address("127.0.0.1:20333") - addr2 = Address("127.0.0.1:20334") - leader.ADDRS = [addr1, addr2] - leader.DEAD_ADDRS = [Address("127.0.0.1:20335")] - test_node = NeoNode() - test_node.host = "127.0.0.1" - test_node.port = 20333 - test_node.address = Address("127.0.0.1:20333") - test_node2 = NeoNode() - test_node2.host = "127.0.0.1" - test_node2.port = 20333 - test_node2.address = Address("127.0.0.1:20334") - leader.Peers = [test_node, test_node2] - - with patch("neo.Network.NeoNode.NeoNode.Disconnect") as mock_disconnect: - # first test if the number of connected peers !< new maxpeers - with patch('sys.stdout', new=StringIO()) as mock_print: - args = ['maxpeers', "4"] - res = 
CommandConfig().execute(args) - self.assertTrue(res) - self.assertEqual(len(leader.Peers), 2) - self.assertFalse(mock_disconnect.called) - self.assertIn(f"Maxpeers set to {settings.CONNECTED_PEER_MAX}", mock_print.getvalue()) - - # now test if the number of connected peers < new maxpeers - with patch('sys.stdout', new=StringIO()) as mock_print: - args = ['maxpeers', "1"] - res = CommandConfig().execute(args) - self.assertTrue(res) - self.assertEqual(len(leader.Peers), 1) - self.assertEqual(leader.Peers[0].address, test_node.address) - self.assertTrue(mock_disconnect.called) - self.assertIn(f"Maxpeers set to {settings.CONNECTED_PEER_MAX}", mock_print.getvalue()) + def test_config_minpeers(self): + # test no input and verify output confirming current minpeers + with patch('sys.stdout', new=StringIO()) as mock_print: + args = ['minpeers'] + res = CommandConfig().execute(args) + self.assertFalse(res) + self.assertIn(f"Maintaining minpeers at {settings.CONNECTED_PEER_MIN}", mock_print.getvalue()) + + # test changing the number of minpeers + with patch('sys.stdout', new=StringIO()) as mock_print: + args = ['minpeers', "6"] + res = CommandConfig().execute(args) + self.assertTrue(res) + self.assertEqual(int(res), settings.CONNECTED_PEER_MIN) + self.assertIn(f"Minpeers set to {settings.CONNECTED_PEER_MIN}", mock_print.getvalue()) + + # test bad input + with patch('sys.stdout', new=StringIO()) as mock_print: + args = ['minpeers', "blah"] + res = CommandConfig().execute(args) + self.assertFalse(res) + self.assertIn("Please supply a positive integer for minpeers", mock_print.getvalue()) + + # test negative number + with patch('sys.stdout', new=StringIO()) as mock_print: + args = ['minpeers', "-1"] + res = CommandConfig().execute(args) + self.assertFalse(res) + self.assertIn("Please supply a positive integer for minpeers", mock_print.getvalue()) + + # test minpeers greater than maxpeers + with patch('sys.stdout', new=StringIO()) as mock_print: + args = ['minpeers', f"{settings.CONNECTED_PEER_MAX + 1}"] + res = CommandConfig().execute(args) + self.assertFalse(res) + self.assertIn("minpeers setting cannot be bigger than maxpeers setting", mock_print.getvalue()) def test_config_nep8(self): # test with missing flag argument diff --git a/neo/Prompt/Commands/tests/test_sc_commands.py b/neo/Prompt/Commands/tests/test_sc_commands.py index ab8be2b6d..7e7fa54a8 100644 --- a/neo/Prompt/Commands/tests/test_sc_commands.py +++ b/neo/Prompt/Commands/tests/test_sc_commands.py @@ -11,6 +11,8 @@ from io import StringIO from boa.compiler import Compiler from neo.Settings import settings +from neo.Network.nodemanager import NodeManager +from neo.Network.node import NeoNode class CommandSCTestCase(WalletFixtureTestCase): @@ -135,14 +137,14 @@ def test_sc_buildrun(self): self.assertIn("Test deploy invoke successful", mock_print.getvalue()) # test successful build and run with prompted input - # PromptData.Wallet = self.GetWallet1(recreate=True) - # with patch('sys.stdout', new=StringIO()) as mock_print: - # with patch('neo.Prompt.Utils.PromptSession.prompt', side_effect=['remove', 'AG4GfwjnvydAZodm4xEDivguCtjCFzLcJy', '3']): - # args = ['build_run', 'neo/Prompt/Commands/tests/SampleSC.py', 'True', 'False', 'False', '070502', '02', '--i'] - # tx, result, total_ops, engine = CommandSC().execute(args) - # self.assertTrue(tx) - # self.assertEqual(str(result[0]), '0') - # self.assertIn("Test deploy invoke successful", mock_print.getvalue()) + PromptData.Wallet = self.GetWallet1(recreate=True) + with patch('sys.stdout', 
new=StringIO()) as mock_print: + with patch('neo.Prompt.Utils.prompt', side_effect=['remove', 'AG4GfwjnvydAZodm4xEDivguCtjCFzLcJy', '3']): + args = ['build_run', 'neo/Prompt/Commands/tests/SampleSC.py', 'True', 'False', 'False', '070502', '02', '--i'] + tx, result, total_ops, engine = CommandSC().execute(args) + self.assertTrue(tx) + self.assertEqual(str(result[0]), '0') + self.assertIn("Test deploy invoke successful", mock_print.getvalue()) # test invoke failure (SampleSC requires three inputs) PromptData.Wallet = self.GetWallet1(recreate=True) @@ -351,7 +353,8 @@ def test_sc_deploy(self): args = ['deploy', path_dir + 'SampleSC.avm', 'True', 'False', 'False', '070502', '02'] res = CommandSC().execute(args) self.assertFalse(res) - self.assertIn("Deploy Invoke TX Fee: 0.00387", mock_print.getvalue()) # notice the required fee is now greater than the low priority threshold + self.assertIn("Deploy Invoke TX Fee: 0.00387", + mock_print.getvalue()) # notice the required fee is now greater than the low priority threshold self.assertTrue(mock_print.getvalue().endswith('Insufficient funds\n')) def test_sc_invoke(self): @@ -450,13 +453,17 @@ def test_sc_invoke(self): self.assertIn("Integer", mock_print.getvalue()) # test ok - with patch('sys.stdout', new=StringIO()) as mock_print: - with patch('neo.Prompt.Commands.SC.prompt', side_effect=[self.wallet_3_pass()]): - args = ['invoke', token_hash_str, 'symbol', '[]', '--fee=0.001'] - res = CommandSC().execute(args) - # not the best check, but will do for now - self.assertTrue(res) - self.assertIn("Priority Fee (0.001) + Invoke TX Fee (0.0001) = 0.0011", mock_print.getvalue()) + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('sys.stdout', new=StringIO()) as mock_print: + with patch('neo.Prompt.Commands.SC.prompt', side_effect=[self.wallet_3_pass()]): + args = ['invoke', token_hash_str, 'symbol', '[]', '--fee=0.001'] + res = CommandSC().execute(args) + # not the best check, but will do for now + self.assertTrue(res) + self.assertIn("Priority Fee (0.001) + Invoke TX Fee (0.0001) = 0.0011", mock_print.getvalue()) def test_sc_debugstorage(self): # test with insufficient parameters diff --git a/neo/Prompt/Commands/tests/test_send_commands.py b/neo/Prompt/Commands/tests/test_send_commands.py index 64a2926d4..ced4ba98a 100644 --- a/neo/Prompt/Commands/tests/test_send_commands.py +++ b/neo/Prompt/Commands/tests/test_send_commands.py @@ -9,9 +9,9 @@ from neo.Prompt.PromptData import PromptData import shutil from mock import patch -import json +from neo.Network.node import NeoNode +from neo.Network.nodemanager import NodeManager from io import StringIO -from neo.Prompt.PromptPrinter import pp class UserWalletTestCase(WalletFixtureTestCase): @@ -36,42 +36,55 @@ def tearDown(cls): PromptData.Wallet = None def test_send_neo(self): - with patch('sys.stdout', new=StringIO()) as mock_print: - with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - PromptData.Wallet = self.GetWallet1(recreate=True) - args = ['send', 'neo', self.watch_addr_str, '50'] + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] - res = Wallet.CommandWallet().execute(args) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('sys.stdout', new=StringIO()) as mock_print: + with patch('neo.Prompt.Commands.Send.prompt', 
side_effect=[UserWalletTestCase.wallet_1_pass()]): + PromptData.Wallet = self.GetWallet1(recreate=True) + args = ['send', 'neo', self.watch_addr_str, '50'] + res = Wallet.CommandWallet().execute(args) - self.assertTrue(res) - self.assertIn("Sending with fee: 0", mock_print.getvalue()) + self.assertTrue(res) + self.assertIn("Sending with fee: 0", mock_print.getvalue()) def test_send_gas(self): - with patch('sys.stdout', new=StringIO()) as mock_print: - with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - PromptData.Wallet = self.GetWallet1(recreate=True) - args = ['send', 'gas', self.watch_addr_str, '5'] + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] - res = Wallet.CommandWallet().execute(args) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('sys.stdout', new=StringIO()) as mock_print: + with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): + PromptData.Wallet = self.GetWallet1(recreate=True) + args = ['send', 'gas', self.watch_addr_str, '5'] + res = Wallet.CommandWallet().execute(args) - self.assertTrue(res) - self.assertIn("Sending with fee: 0", mock_print.getvalue()) + self.assertTrue(res) + self.assertIn("Sending with fee: 0", mock_print.getvalue()) def test_send_with_fee_and_from_addr(self): - with patch('sys.stdout', new=StringIO()) as mock_print: - with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - PromptData.Wallet = self.GetWallet1(recreate=True) - args = ['send', 'neo', self.watch_addr_str, '1', '--from-addr=AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3', '--fee=0.005'] + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] - res = Wallet.CommandWallet().execute(args) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('sys.stdout', new=StringIO()) as mock_print: + with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): + PromptData.Wallet = self.GetWallet1(recreate=True) + args = ['send', 'neo', self.watch_addr_str, '1', '--from-addr=AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3', '--fee=0.005'] - self.assertTrue(res) # verify successful tx + res = Wallet.CommandWallet().execute(args) - json_res = res.ToJson() - self.assertEqual(self.watch_addr_str, json_res['vout'][0]['address']) # verify correct address_to - self.assertEqual(self.wallet_1_addr, json_res['vout'][1]['address']) # verify correct address_from - self.assertEqual(json_res['net_fee'], "0.005") # verify correct fee - self.assertIn("Sending with fee: 0.005", mock_print.getvalue()) + self.assertTrue(res) # verify successful tx + + json_res = res.ToJson() + self.assertEqual(self.watch_addr_str, json_res['vout'][0]['address']) # verify correct address_to + self.assertEqual(self.wallet_1_addr, json_res['vout'][1]['address']) # verify correct address_from + self.assertEqual(json_res['net_fee'], "0.005") # verify correct fee + self.assertIn("Sending with fee: 0.005", mock_print.getvalue()) def test_send_no_wallet(self): with patch('sys.stdout', new=StringIO()) as mock_print: @@ -198,20 +211,25 @@ def test_send_token_bad(self): self.assertIn("Could not find the contract hash", mock_print.getvalue()) def test_send_token_ok(self): - with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - with patch('sys.stdout', new=StringIO()) as mock_print: - 
PromptData.Wallet = self.GetWallet1(recreate=True) + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] - token_hash = '31730cc9a1844891a3bafd1aa929a4142860d8d3' - ImportToken(PromptData.Wallet, token_hash) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): + with patch('sys.stdout', new=StringIO()) as mock_print: + PromptData.Wallet = self.GetWallet1(recreate=True) - args = ['send', 'NXT4', self.watch_addr_str, '30', '--from-addr=%s' % self.wallet_1_addr] + token_hash = '31730cc9a1844891a3bafd1aa929a4142860d8d3' + ImportToken(PromptData.Wallet, token_hash) - res = Wallet.CommandWallet().execute(args) + args = ['send', 'NXT4', self.watch_addr_str, '30', '--from-addr=%s' % self.wallet_1_addr] - self.assertTrue(res) - self.assertIn("Will transfer 30.00000000 NXT4 from AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3 to AGYaEi3W6ndHPUmW7T12FFfsbQ6DWymkEm", - mock_print.getvalue()) + res = Wallet.CommandWallet().execute(args) + + self.assertTrue(res) + self.assertIn("Will transfer 30.00000000 NXT4 from AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3 to AGYaEi3W6ndHPUmW7T12FFfsbQ6DWymkEm", + mock_print.getvalue()) def test_insufficient_funds(self): @@ -282,35 +300,51 @@ def test_owners(self, mock): self.assertTrue(mock.called) def test_attributes(self): - with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - PromptData.Wallet = self.GetWallet1(recreate=True) - args = ['send', 'gas', self.watch_addr_str, '2', '--tx-attr={"usage":241,"data":"This is a remark"}'] + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] - res = Wallet.CommandWallet().execute(args) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): + PromptData.Wallet = self.GetWallet1(recreate=True) + args = ['send', 'gas', self.watch_addr_str, '2', '--tx-attr={"usage":241,"data":"This is a remark"}'] + + res = Wallet.CommandWallet().execute(args) - self.assertTrue(res) - self.assertEqual(2, len( - res.Attributes)) # By default the script_hash of the transaction sender is added to the TransactionAttribute list, therefore the Attributes length is `count` + 1 + self.assertTrue(res) + self.assertEqual(2, len( + res.Attributes)) # By default the script_hash of the transaction sender is added to the TransactionAttribute list, therefore the Attributes length is `count` + 1 def test_multiple_attributes(self): - with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - PromptData.Wallet = self.GetWallet1(recreate=True) - args = ['send', 'gas', self.watch_addr_str, '2', '--tx-attr=[{"usage":241,"data":"This is a remark"},{"usage":242,"data":"This is a remark 2"}]'] + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] - res = Wallet.CommandWallet().execute(args) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): + PromptData.Wallet = self.GetWallet1(recreate=True) + args = ['send', 'gas', self.watch_addr_str, '2', + '--tx-attr=[{"usage":241,"data":"This is a remark"},{"usage":242,"data":"This is a remark 2"}]'] - self.assertTrue(res) - self.assertEqual(3, 
len(res.Attributes)) + res = Wallet.CommandWallet().execute(args) + + self.assertTrue(res) + self.assertEqual(3, len(res.Attributes)) def test_bad_attributes(self): - with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - PromptData.Wallet = self.GetWallet1(recreate=True) - args = ['send', 'gas', self.watch_addr_str, '2', '--tx-attr=[{"usa:241"data":his is a remark"}]'] + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] - res = Wallet.CommandWallet().execute(args) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): + PromptData.Wallet = self.GetWallet1(recreate=True) + args = ['send', 'gas', self.watch_addr_str, '2', '--tx-attr=[{"usa:241"data":his is a remark"}]'] - self.assertTrue(res) - self.assertEqual(1, len(res.Attributes)) + res = Wallet.CommandWallet().execute(args) + + self.assertTrue(res) + self.assertEqual(1, len(res.Attributes)) def test_utils_attr_str(self): @@ -348,8 +382,11 @@ def test_fails_to_sign_tx(self): mock_print.getvalue()) def test_fails_to_relay_tx(self): + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] with patch('neo.Prompt.Commands.Send.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - with patch('neo.Prompt.Commands.Send.NodeLeader.Relay', return_value=False): + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(False)): with patch('sys.stdout', new=StringIO()) as mock_print: PromptData.Wallet = self.GetWallet1(recreate=True) args = ['send', 'gas', self.watch_addr_str, '2'] @@ -358,6 +395,7 @@ def test_fails_to_relay_tx(self): self.assertFalse(res) self.assertIn("Could not relay tx", mock_print.getvalue()) + nodemgr.reset_for_test() def test_could_not_send(self): # mocking traceback module to avoid stacktrace printing during test run @@ -373,48 +411,57 @@ def test_could_not_send(self): self.assertIn("Could not send:", mock_print.getvalue()) def test_sendmany_good_simple(self): - with patch('sys.stdout', new=StringIO()) as mock_print: - with patch('neo.Prompt.Commands.Send.prompt', - side_effect=["neo", self.watch_addr_str, "1", "gas", self.watch_addr_str, "1", UserWalletTestCase.wallet_1_pass()]): - PromptData.Wallet = self.GetWallet1(recreate=True) - args = ['sendmany', '2'] + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] - res = Wallet.CommandWallet().execute(args) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('sys.stdout', new=StringIO()) as mock_print: + with patch('neo.Prompt.Commands.Send.prompt', + side_effect=["neo", self.watch_addr_str, "1", "gas", self.watch_addr_str, "1", UserWalletTestCase.wallet_1_pass()]): + PromptData.Wallet = self.GetWallet1(recreate=True) + args = ['sendmany', '2'] + res = Wallet.CommandWallet().execute(args) - self.assertTrue(res) # verify successful tx - self.assertIn("Sending with fee: 0", mock_print.getvalue()) - json_res = res.ToJson() + self.assertTrue(res) # verify successful tx + self.assertIn("Sending with fee: 0", mock_print.getvalue()) + json_res = res.ToJson() - # check for 2 transfers - transfers = 0 - for info in json_res['vout']: - if info['address'] == self.watch_addr_str: - transfers += 1 - self.assertEqual(2, transfers) + # check for 2 transfers + transfers = 0 + for info in json_res['vout']: + if 
info['address'] == self.watch_addr_str: + transfers += 1 + self.assertEqual(2, transfers) def test_sendmany_good_complex(self): - with patch('sys.stdout', new=StringIO()) as mock_print: - with patch('neo.Prompt.Commands.Send.prompt', - side_effect=["neo", "AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK", "1", "gas", "AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK", "1", - UserWalletTestCase.wallet_1_pass()]): - PromptData.Wallet = self.GetWallet1(recreate=True) - args = ['sendmany', '2', '--from-addr=%s' % self.wallet_1_addr, '--change-addr=%s' % self.watch_addr_str, '--fee=0.005'] + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] - address_from_account_state = Blockchain.Default().GetAccountState(self.wallet_1_addr).ToJson() - address_from_gas = next(filter(lambda b: b['asset'] == '0x602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7', - address_from_account_state['balances'])) - address_from_gas_bal = address_from_gas['value'] + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('sys.stdout', new=StringIO()) as mock_print: + with patch('neo.Prompt.Commands.Send.prompt', + side_effect=["neo", "AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK", "1", "gas", "AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK", "1", + UserWalletTestCase.wallet_1_pass()]): + PromptData.Wallet = self.GetWallet1(recreate=True) + args = ['sendmany', '2', '--from-addr=%s' % self.wallet_1_addr, '--change-addr=%s' % self.watch_addr_str, '--fee=0.005'] - res = Wallet.CommandWallet().execute(args) + address_from_account_state = Blockchain.Default().GetAccountState(self.wallet_1_addr).ToJson() + address_from_gas = next(filter(lambda b: b['asset'] == '0x602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7', + address_from_account_state['balances'])) + address_from_gas_bal = address_from_gas['value'] + + res = Wallet.CommandWallet().execute(args) - self.assertTrue(res) # verify successful tx + self.assertTrue(res) # verify successful tx - json_res = res.ToJson() - self.assertEqual("AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK", json_res['vout'][0]['address']) # verify correct address_to - self.assertEqual(self.watch_addr_str, json_res['vout'][2]['address']) # verify correct change address - self.assertEqual(float(address_from_gas_bal) - 1 - 0.005, float(json_res['vout'][3]['value'])) - self.assertEqual('0.005', json_res['net_fee']) - self.assertIn("Sending with fee: 0.005", mock_print.getvalue()) + json_res = res.ToJson() + self.assertEqual("AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK", json_res['vout'][0]['address']) # verify correct address_to + self.assertEqual(self.watch_addr_str, json_res['vout'][2]['address']) # verify correct change address + self.assertEqual(float(address_from_gas_bal) - 1 - 0.005, float(json_res['vout'][3]['value'])) + self.assertEqual('0.005', json_res['net_fee']) + self.assertIn("Sending with fee: 0.005", mock_print.getvalue()) def test_sendmany_no_wallet(self): with patch('sys.stdout', new=StringIO()) as mock_print: diff --git a/neo/Prompt/Commands/tests/test_show_commands.py b/neo/Prompt/Commands/tests/test_show_commands.py index a753b76e7..252db48a6 100644 --- a/neo/Prompt/Commands/tests/test_show_commands.py +++ b/neo/Prompt/Commands/tests/test_show_commands.py @@ -5,11 +5,13 @@ from neo.Prompt.Commands.Wallet import CommandWallet from neo.Prompt.PromptData import PromptData from neo.bin.prompt import PromptInterface -from neo.Network.NodeLeader import NodeLeader, NeoNode from neo.Core.Blockchain import Blockchain from 
neo.Implementations.Wallets.peewee.UserWallet import UserWallet -from mock import patch -from neo.Network.address import Address +from mock import mock, patch, MagicMock +from neo.Network.nodemanager import NodeManager +from neo.Network.node import NeoNode +from neo.Network.common.singleton import Singleton +from io import StringIO class CommandShowTestCase(BlockchainFixtureTestCase): @@ -122,48 +124,109 @@ def test_show_mem(self): self.assertTrue(res) def test_show_nodes(self): - # query nodes with no NodeLeader.Instance() - with patch('neo.Network.NodeLeader.NodeLeader.Instance'): - args = ['nodes'] + nodemgr = NodeManager() + nodemgr.reset_for_test() + + # test "nodes" with no nodes connected + args = ['nodes'] + with patch('sys.stdout', new=StringIO()) as mock_print: res = CommandShow().execute(args) self.assertFalse(res) + self.assertIn('No nodes connected yet', mock_print.getvalue()) + + # test "nodes verbose" with no nodes connected + args = ['nodes', 'verbose'] + res = CommandShow().execute(args) + self.assertIn('Addresses in queue: 0', res) + self.assertIn('Known addresses: 0', res) + self.assertIn('Bad addresses: 0', res) + + # test "nodes queued" with no nodes connected + args = ['nodes', 'queued'] + res = CommandShow().execute(args) + self.assertIn('No queued addresses', res) + + # test "nodes known" with no nodes connected + args = ['nodes', 'known'] + res = CommandShow().execute(args) + self.assertIn('No known addresses other than connect peers', res) + + # test "nodes bad" with no nodes connected + args = ['nodes', 'bad'] + res = CommandShow().execute(args) + self.assertIn('No bad addresses', res) # query nodes with connected peers # first make sure we have a predictable state - NodeLeader.Instance().Reset() - leader = NodeLeader.Instance() - addr1 = Address("127.0.0.1:20333") - addr2 = Address("127.0.0.1:20334") - leader.ADDRS = [addr1, addr2] - leader.DEAD_ADDRS = [Address("127.0.0.1:20335")] - test_node = NeoNode() - test_node.host = "127.0.0.1" - test_node.port = 20333 - test_node.address = Address("127.0.0.1:20333") - leader.Peers = [test_node] - - # now show nodes - with patch('neo.Network.NeoNode.NeoNode.Name', return_value="test name"): - args = ['nodes'] - res = CommandShow().execute(args) - self.assertTrue(res) - self.assertIn('Total Connected: 1', res) - self.assertIn('Peer 0', res) + node1 = NeoNode(object, object) + node2 = NeoNode(object, object) + node1.address = "127.0.0.1:20333" + node2.address = "127.0.0.1:20334" + node1.best_height = 1025 + node2.best_height = 1026 + node1.version = MagicMock() + node2.version = MagicMock() + node1.version.user_agent = "test_user_agent" + node2.version.user_agent = "test_user_agent" - # now use "node" - args = ['node'] - res = CommandShow().execute(args) - self.assertTrue(res) - self.assertIn('Total Connected: 1', res) - self.assertIn('Peer 0', res) + nodemgr.nodes = [node1, node2] + + queued_address = "127.0.0.1:20335" + known_address = "127.0.0.1:20336" + bad_address = "127.0.0.1:20337" + + nodemgr.queued_addresses.append(queued_address) + nodemgr.known_addresses.append(known_address) + nodemgr.bad_addresses.append(bad_address) - def test_show_state(self): + # now use "node" + args = ['node'] + res = CommandShow().execute(args) + self.assertIn("Connected: 2", res) + self.assertIn("Peer 1", res) + self.assertIn("1025", res) + + # test "nodes verbose" with queued, known, and bad addresses + args = ['nodes', 'verbose'] + res = CommandShow().execute(args) + self.assertIn("Addresses in queue: 1", res) + self.assertIn("Known 
addresses: 1", res) + self.assertIn("Bad addresses: 1", res) + + # test "nodes queued" with queued, known, and bad addresses + args = ['nodes', 'queued'] + res = CommandShow().execute(args) + self.assertIn("Queued addresses:", res) + self.assertIn(queued_address, res) + + # test "nodes known" with queued, known, and bad addresses + args = ['nodes', 'known'] + res = CommandShow().execute(args) + self.assertIn("Known addresses:", res) + self.assertIn(known_address, res) + + # test "nodes bad" with queued, known, and bad addresses + args = ['nodes', 'bad'] + res = CommandShow().execute(args) + self.assertIn("Bad addresses:", res) + self.assertIn(bad_address, res) + + nodemgr.reset_for_test() + + @mock.patch('neo.Prompt.Commands.Show.SyncManager') + def test_show_state(self, mock_SyncManager): # setup + class mock_SM(Singleton): + def init(self): + self.block_cache = [1, 2, 3, 4, 5] # simulate blocks in the block_cache + mock_SyncManager.return_value = mock_SM() PromptInterface() - args = ['state'] - res = CommandShow().execute(args) - self.assertTrue(res) + with patch('sys.stdout', new=StringIO()) as mock_print: + args = ['state'] + res = CommandShow().execute(args) + self.assertTrue(res) + self.assertIn("Block-cache length 5", mock_print.getvalue()) def test_show_notifications(self): # setup @@ -264,10 +327,12 @@ def test_show_account(self): # test empty account with patch('neo.Prompt.PromptData.PromptData.Prompt'): with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=["testpassword", "testpassword"]): - args = ['create', 'testwallet.wallet'] - res = CommandWallet().execute(args) - self.assertTrue(res) - self.assertIsInstance(res, UserWallet) + with patch('neo.Prompt.Commands.Wallet.asyncio'): + with patch('neo.Wallets.Wallet.Wallet.sync_wallet'): + args = ['create', 'testwallet.wallet'] + res = CommandWallet().execute(args) + self.assertTrue(res) + self.assertIsInstance(res, UserWallet) addr = res.Addresses[0] args = ['account', addr] diff --git a/neo/Prompt/Commands/tests/test_token_commands.py b/neo/Prompt/Commands/tests/test_token_commands.py index 53cf543aa..6151a781b 100644 --- a/neo/Prompt/Commands/tests/test_token_commands.py +++ b/neo/Prompt/Commands/tests/test_token_commands.py @@ -15,9 +15,11 @@ from mock import patch from neo.Prompt.PromptData import PromptData from contextlib import contextmanager -from io import StringIO, TextIOWrapper +from io import StringIO from neo.VM.InteropService import StackItem from neo.Prompt.PromptPrinter import pp +from neo.Network.nodemanager import NodeManager +from neo.Network.node import NeoNode class UserWalletTestCase(WalletFixtureTestCase): @@ -126,50 +128,65 @@ def test_token_balance(self): self.assertEqual(balance, 2499000) def test_token_send_good(self): - with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - wallet = self.GetWallet1(recreate=True) - token = self.get_token(wallet) - addr_from = wallet.GetDefaultContract().Address - addr_to = self.watch_addr_str - fee = Fixed8.FromDecimal(0.001) + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] + + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): + wallet = self.GetWallet1(recreate=True) + token = self.get_token(wallet) + addr_from = wallet.GetDefaultContract().Address + addr_to = self.watch_addr_str + fee = Fixed8.FromDecimal(0.001) - send = 
token_send(wallet, token.symbol, addr_from, addr_to, 1300, fee) + send = token_send(wallet, token.symbol, addr_from, addr_to, 1300, fee) - self.assertTrue(send) - res = send.ToJson() - self.assertEqual(res["vout"][0]["address"], "AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3") - self.assertEqual(res["net_fee"], "0.0011") + self.assertTrue(send) + res = send.ToJson() + self.assertEqual(res["vout"][0]["address"], "AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3") + self.assertEqual(res["net_fee"], "0.0011") def test_token_send_with_user_attributes(self): - with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - wallet = self.GetWallet1(recreate=True) - token = self.get_token(wallet) - addr_from = wallet.GetDefaultContract().Address - addr_to = self.watch_addr_str - _, attributes = get_tx_attr_from_args(['--tx-attr=[{"usage":241,"data":"This is a remark"},{"usage":242,"data":"This is a remark 2"}]']) + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] - send = token_send(wallet, token.symbol, addr_from, addr_to, 1300, user_tx_attributes=attributes) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): + wallet = self.GetWallet1(recreate=True) + token = self.get_token(wallet) + addr_from = wallet.GetDefaultContract().Address + addr_to = self.watch_addr_str + _, attributes = get_tx_attr_from_args(['--tx-attr=[{"usage":241,"data":"This is a remark"},{"usage":242,"data":"This is a remark 2"}]']) - self.assertTrue(send) - res = send.ToJson() - self.assertEqual(len(res['attributes']), 3) - self.assertEqual(res['attributes'][0]['usage'], 241) - self.assertEqual(res['attributes'][1]['usage'], 242) + send = token_send(wallet, token.symbol, addr_from, addr_to, 1300, user_tx_attributes=attributes) + + self.assertTrue(send) + res = send.ToJson() + self.assertEqual(len(res['attributes']), 3) + self.assertEqual(res['attributes'][0]['usage'], 241) + self.assertEqual(res['attributes'][1]['usage'], 242) def test_token_send_bad_user_attributes(self): - with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - wallet = self.GetWallet1(recreate=True) - token = self.get_token(wallet) - addr_from = wallet.GetDefaultContract().Address - addr_to = self.watch_addr_str + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] - _, attributes = get_tx_attr_from_args(['--tx-attr=[{"usa:241,"data":"This is a remark"}]']) - send = token_send(wallet, token.symbol, addr_from, addr_to, 100, user_tx_attributes=attributes) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): + wallet = self.GetWallet1(recreate=True) + token = self.get_token(wallet) + addr_from = wallet.GetDefaultContract().Address + addr_to = self.watch_addr_str - self.assertTrue(send) - res = send.ToJson() - self.assertEqual(1, len(res['attributes'])) - self.assertNotEqual(241, res['attributes'][0]['usage']) + _, attributes = get_tx_attr_from_args(['--tx-attr=[{"usa:241,"data":"This is a remark"}]']) + send = token_send(wallet, token.symbol, addr_from, addr_to, 100, user_tx_attributes=attributes) + + self.assertTrue(send) + res = send.ToJson() + self.assertEqual(1, len(res['attributes'])) + self.assertNotEqual(241, res['attributes'][0]['usage']) def 
test_token_send_bad_args(self): # too few args wallet = self.GetWallet1(recreate=True) @@ -260,20 +277,25 @@ def test_token_allowance_no_tx(self): self.assertIn("Could not get allowance", str(context.exception)) def test_token_mint_good(self): - with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): - wallet = self.GetWallet1(recreate=True) - token = self.get_token(wallet) - addr_to = self.wallet_1_addr - asset_attachments = ['--attach-neo=10'] - _, tx_attr = PromptUtils.get_tx_attr_from_args(['--tx-attr={"usage":241,"data":"This is a remark"}']) + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] + + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[UserWalletTestCase.wallet_1_pass()]): + wallet = self.GetWallet1(recreate=True) + token = self.get_token(wallet) + addr_to = self.wallet_1_addr + asset_attachments = ['--attach-neo=10'] + _, tx_attr = PromptUtils.get_tx_attr_from_args(['--tx-attr={"usage":241,"data":"This is a remark"}']) - mint = token_mint(token, wallet, addr_to, asset_attachments=asset_attachments, invoke_attrs=tx_attr) + mint = token_mint(token, wallet, addr_to, asset_attachments=asset_attachments, invoke_attrs=tx_attr) - self.assertTrue(mint) - res = mint.ToJson() - self.assertEqual(res['attributes'][1]['usage'], 241) # verifies attached attribute - self.assertEqual(res['vout'][0]['value'], "10") # verifies attached neo - self.assertEqual(res['vout'][0]['address'], "Ab61S1rk2VtCVd3NtGNphmBckWk4cfBdmB") # verifies attached neo sent to token contract owner + self.assertTrue(mint) + res = mint.ToJson() + self.assertEqual(res['attributes'][1]['usage'], 241) # verifies attached attribute + self.assertEqual(res['vout'][0]['value'], "10") # verifies attached neo + self.assertEqual(res['vout'][0]['address'], "Ab61S1rk2VtCVd3NtGNphmBckWk4cfBdmB") # verifies attached neo sent to token contract owner def test_token_mint_no_tx(self): with patch('neo.Wallets.NEP5Token.NEP5Token.Mint', return_value=(None, 0, None)): @@ -614,12 +636,17 @@ def test_wallet_token_approve(self): self.assertIn("Failed to approve tokens", mock_print.getvalue()) # test successful approval - with patch('sys.stdout', new=StringIO()) as mock_print: - with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[self.wallet_1_pass()]): - args = ['token', 'approve', 'NXT4', addr_from, addr_to, '123', '--fee=0.001'] - res = CommandWallet().execute(args) - self.assertTrue(res) - self.assertIn("Priority Fee (0.001) + Invocation Fee (0.0001) = 0.0011", mock_print.getvalue()) + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] + + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('sys.stdout', new=StringIO()) as mock_print: + with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[self.wallet_1_pass()]): + args = ['token', 'approve', 'NXT4', addr_from, addr_to, '123', '--fee=0.001'] + res = CommandWallet().execute(args) + self.assertTrue(res) + self.assertIn("Priority Fee (0.001) + Invocation Fee (0.0001) = 0.0011", mock_print.getvalue()) def test_wallet_token_allowance(self): with self.OpenWallet1(): @@ -766,13 +793,19 @@ def test_token_mint(self): self.assertIn("Token mint cancelled", mock_print.getvalue()) # test working minting - with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[self.wallet_1_pass()]): - with patch('sys.stdout', 
new=StringIO()) as mock_print: - args = ['token', 'mint', 'NXT4', 'AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y', '--fee=0.001', '--tx-attr={"usage":241,"data":"This is a remark"}'] - res = CommandWallet().execute(args) - self.assertTrue(res) - self.assertIn("[NXT4] Will mint tokens to address", mock_print.getvalue()) - self.assertIn("Priority Fee (0.001) + Invocation Fee (0.0001) = 0.0011", mock_print.getvalue()) + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] + + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[self.wallet_1_pass()]): + with patch('sys.stdout', new=StringIO()) as mock_print: + args = ['token', 'mint', 'NXT4', 'AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y', '--fee=0.001', + '--tx-attr={"usage":241,"data":"This is a remark"}'] + res = CommandWallet().execute(args) + self.assertTrue(res) + self.assertIn("[NXT4] Will mint tokens to address", mock_print.getvalue()) + self.assertIn("Priority Fee (0.001) + Invocation Fee (0.0001) = 0.0011", mock_print.getvalue()) def test_token_register(self): with self.OpenWallet1(): @@ -844,13 +877,18 @@ def test_token_register(self): self.assertIn("Registration cancelled", mock_print.getvalue()) # test with valid address - with patch('sys.stdout', new=StringIO()) as mock_print: - with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[self.wallet_1_pass()]): - args = ['token', 'register', 'NXT4', 'AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y', '--fee=0.001'] - res = CommandWallet().execute(args) - self.assertTrue(res) - self.assertIn("[NXT4] Will register addresses", mock_print.getvalue()) - self.assertIn("Priority Fee (0.001) + Invocation Fee (0.0001) = 0.0011", mock_print.getvalue()) + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] + + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('sys.stdout', new=StringIO()) as mock_print: + with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[self.wallet_1_pass()]): + args = ['token', 'register', 'NXT4', 'AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y', '--fee=0.001'] + res = CommandWallet().execute(args) + self.assertTrue(res) + self.assertIn("[NXT4] Will register addresses", mock_print.getvalue()) + self.assertIn("Priority Fee (0.001) + Invocation Fee (0.0001) = 0.0011", mock_print.getvalue()) # utility function def Approve_Allowance(self): @@ -1027,14 +1065,19 @@ def test_wallet_token_sendfrom(self): self.assertIn("Insufficient allowance", mock_print.getvalue()) # successful test - with patch('sys.stdout', new=StringIO()) as mock_print: - with patch('neo.Prompt.Commands.Tokens.token_get_allowance', return_value=12300000000): - with patch('neo.Wallets.NEP5Token.NEP5Token.TransferFrom', return_value=self.Approve_Allowance(PromptData.Wallet, token)): - with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[self.wallet_1_pass()]): - args = ['token', 'sendfrom', 'NXT4', addr_from, addr_to, '123', '--fee=0.001'] - res = CommandWallet().execute(args) - self.assertTrue(res) - self.assertIn("Priority Fee (0.001) + Transfer Fee (0.0001) = 0.0011", mock_print.getvalue()) + nodemgr = NodeManager() + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object, object)] + + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + with patch('sys.stdout', new=StringIO()) as mock_print: + with patch('neo.Prompt.Commands.Tokens.token_get_allowance', 
return_value=12300000000): + with patch('neo.Wallets.NEP5Token.NEP5Token.TransferFrom', return_value=self.Approve_Allowance(PromptData.Wallet, token)): + with patch('neo.Prompt.Commands.Tokens.prompt', side_effect=[self.wallet_1_pass()]): + args = ['token', 'sendfrom', 'NXT4', addr_from, addr_to, '123', '--fee=0.001'] + res = CommandWallet().execute(args) + self.assertTrue(res) + self.assertIn("Priority Fee (0.001) + Transfer Fee (0.0001) = 0.0011", mock_print.getvalue()) def Approve_Allowance(self, wallet, token): approve_from = self.wallet_1_addr diff --git a/neo/Prompt/Commands/tests/test_wallet_commands.py b/neo/Prompt/Commands/tests/test_wallet_commands.py index d564c8159..96415547d 100644 --- a/neo/Prompt/Commands/tests/test_wallet_commands.py +++ b/neo/Prompt/Commands/tests/test_wallet_commands.py @@ -7,9 +7,11 @@ from neo.Prompt.Commands.Wallet import CommandWallet from neo.Prompt.Commands.Wallet import ShowUnspentCoins from neo.Prompt.PromptData import PromptData -from neo.Prompt.PromptPrinter import pp +from neo.Network.nodemanager import NodeManager +from neo.Network.node import NeoNode import os import shutil +import asyncio from mock import patch from io import StringIO @@ -92,14 +94,15 @@ def remove_new_wallet(): with patch('neo.Prompt.PromptData.PromptData.Prompt'): with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=["testpassword", "testpassword"]): - # test wallet create successful - path = UserWalletTestCase.new_wallet_dest() - args = ['create', path] - self.assertFalse(os.path.isfile(path)) - res = CommandWallet().execute(args) - self.assertEqual(type(res), UserWallet) - self.assertTrue(os.path.isfile(path)) - remove_new_wallet() + with patch('neo.Prompt.Commands.Wallet.asyncio'): + # test wallet create successful + path = UserWalletTestCase.new_wallet_dest() + args = ['create', path] + self.assertFalse(os.path.isfile(path)) + res = CommandWallet().execute(args) + self.assertEqual(type(res), UserWallet) + self.assertTrue(os.path.isfile(path)) + remove_new_wallet() # test wallet create with no path with patch('sys.stdout', new=StringIO()) as mock_print: @@ -111,18 +114,19 @@ def remove_new_wallet(): # test wallet open with already existing path with patch('sys.stdout', new=StringIO()) as mock_print: with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=["testpassword", "testpassword"]): - path = UserWalletTestCase.new_wallet_dest() - args = ['create', path] - self.assertFalse(os.path.isfile(path)) - res = CommandWallet().execute(args) - self.assertEqual(type(res), UserWallet) - self.assertTrue(os.path.isfile(path)) + with patch('neo.Prompt.Commands.Wallet.asyncio'): + path = UserWalletTestCase.new_wallet_dest() + args = ['create', path] + self.assertFalse(os.path.isfile(path)) + res = CommandWallet().execute(args) + self.assertEqual(type(res), UserWallet) + self.assertTrue(os.path.isfile(path)) - res = CommandWallet().execute(args) - self.assertFalse(res) - self.assertTrue(os.path.isfile(path)) - self.assertIn("File already exists", mock_print.getvalue()) - remove_new_wallet() + res = CommandWallet().execute(args) + self.assertFalse(res) + self.assertTrue(os.path.isfile(path)) + self.assertIn("File already exists", mock_print.getvalue()) + remove_new_wallet() # test wallet with different passwords with patch('sys.stdout', new=StringIO()) as mock_print: @@ -180,84 +184,64 @@ def remove_new_wallet(): remove_new_wallet() def test_wallet_open(self): - with patch('neo.Prompt.PromptData.PromptData.Prompt'): - with patch('neo.Prompt.Commands.Wallet.prompt', 
side_effect=[self.wallet_1_pass()]): + loop = asyncio.get_event_loop() + + with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=[self.wallet_1_pass()]): + async def run_test(): if self._wallet1 is None: shutil.copyfile(self.wallet_1_path(), self.wallet_1_dest()) # test wallet open successful args = ['open', self.wallet_1_dest()] - res = CommandWallet().execute(args) - self.assertEqual(type(res), UserWallet) - # test wallet open with no path; this will also close the open wallet - with patch('sys.stdout', new=StringIO()) as mock_print: + # test wallet open with no path; this will also close the open wallet args = ['open'] - res = CommandWallet().execute(args) - self.assertFalse(res) - self.assertIn("Please specify the required parameter", mock_print.getvalue()) - # test wallet open with bad path - with patch('sys.stdout', new=StringIO()) as mock_print: + # test wallet open with bad path args = ['open', 'badpath'] - res = CommandWallet().execute(args) - self.assertFalse(res) - self.assertIn("Wallet file not found", mock_print.getvalue()) + + loop.run_until_complete(run_test()) # test wallet open unsuccessful - with patch('sys.stdout', new=StringIO()) as mock_print: - with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=["testpassword"]): - with patch('neo.Implementations.Wallets.peewee.UserWallet.UserWallet.Open', side_effect=[Exception('test exception')]): + with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=["testpassword"]): + with patch('neo.Implementations.Wallets.peewee.UserWallet.UserWallet.Open', side_effect=[Exception('test exception')]): + async def run_test(): args = ['open', 'fixtures/testwallet.db3'] - res = CommandWallet().execute(args) - self.assertFalse(res) - self.assertIn("Could not open wallet", mock_print.getvalue()) - # test wallet open with keyboard interrupt - with patch('sys.stdout', new=StringIO()) as mock_print: - with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=[KeyboardInterrupt]): - args = ['open', self.wallet_1_dest()] - - res = CommandWallet().execute(args) - - self.assertFalse(res) - self.assertIn("Wallet opening cancelled", mock_print.getvalue()) + loop.run_until_complete(run_test()) def test_wallet_close(self): - with patch('neo.Prompt.PromptData.PromptData.Prompt'): - # test wallet close with no wallet - args = ['close'] - - res = CommandWallet().execute(args) - - self.assertFalse(res) + loop = asyncio.get_event_loop() + # test wallet close with no wallet + args = ['close'] + res = CommandWallet().execute(args) + self.assertFalse(res) - # test wallet close with open wallet - with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=[self.wallet_1_pass()]): + # test wallet close with open wallet + with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=[self.wallet_1_pass()]): + async def run_test(): if self._wallet1 is None: shutil.copyfile(self.wallet_1_path(), self.wallet_1_dest()) args = ['open', self.wallet_1_dest()] - res = CommandWallet().execute(args) - self.assertEqual(type(res), UserWallet) # now close the open wallet manually args = ['close'] - res = CommandWallet().execute(args) - self.assertTrue(res) + loop.run_until_complete(run_test()) + def test_wallet_verbose(self): # test wallet verbose with no wallet opened args = ['verbose'] @@ -266,10 +250,37 @@ def test_wallet_verbose(self): self.OpenWallet1() - # test wallet close with open wallet - args = ['verbose'] - res = CommandWallet().execute(args) - self.assertTrue(res) + # first test normal wallet printing + with patch('sys.stdout', new=StringIO()) as 
mock_print: + args = [''] + res = CommandWallet().execute(args) + self.assertTrue(res) + self.assertNotIn("Script hash", mock_print.getvalue()) + self.assertNotIn("Public key", mock_print.getvalue()) + + # now test wallet verbose with open wallet + with patch('sys.stdout', new=StringIO()) as mock_print: + args = ['verbose'] + res = CommandWallet().execute(args) + self.assertTrue(res) + self.assertIn("Script hash", mock_print.getvalue()) + self.assertIn("Public key", mock_print.getvalue()) + + # also test "v" + with patch('sys.stdout', new=StringIO()) as mock_print: + args = ['v'] + res = CommandWallet().execute(args) + self.assertTrue(res) + self.assertIn("Script hash", mock_print.getvalue()) + self.assertIn("Public key", mock_print.getvalue()) + + # and "--v" + with patch('sys.stdout', new=StringIO()) as mock_print: + args = ['--v'] + res = CommandWallet().execute(args) + self.assertTrue(res) + self.assertIn("Script hash", mock_print.getvalue()) + self.assertIn("Public key", mock_print.getvalue()) def test_wallet_claim_1(self): # test with no wallet @@ -291,14 +302,18 @@ def test_wallet_claim_1(self): self.assertIn("Incorrect password", mock_print.getvalue()) # test successful + nodemgr = NodeManager() + nodemgr.nodes = [NeoNode(object, object)] with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=[WalletFixtureTestCase.wallet_1_pass()]): - args = ['claim'] - claim_tx, relayed = CommandWallet().execute(args) - self.assertIsInstance(claim_tx, ClaimTransaction) - self.assertTrue(relayed) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + args = ['claim'] + claim_tx, relayed = CommandWallet().execute(args) + self.assertIsInstance(claim_tx, ClaimTransaction) + self.assertTrue(relayed) - json_tx = claim_tx.ToJson() - self.assertEqual(json_tx['vout'][0]['address'], self.wallet_1_addr) + json_tx = claim_tx.ToJson() + self.assertEqual(json_tx['vout'][0]['address'], self.wallet_1_addr) + nodemgr.reset_for_test() # test nothing to claim anymore with patch('sys.stdout', new=StringIO()) as mock_print: @@ -331,14 +346,19 @@ def test_wallet_claim_2(self): self.assertIn("Address format error", mock_print.getvalue()) # successful test with --from-addr + nodemgr = NodeManager() + nodemgr.nodes = [NeoNode(object, object)] + with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=[WalletFixtureTestCase.wallet_2_pass()]): - args = ['claim', '--from-addr=' + self.wallet_1_addr] - claim_tx, relayed = CommandWallet().execute(args) - self.assertIsInstance(claim_tx, ClaimTransaction) - self.assertTrue(relayed) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + args = ['claim', '--from-addr=' + self.wallet_1_addr] + claim_tx, relayed = CommandWallet().execute(args) + self.assertIsInstance(claim_tx, ClaimTransaction) + self.assertTrue(relayed) - json_tx = claim_tx.ToJson() - self.assertEqual(json_tx['vout'][0]['address'], self.wallet_1_addr) + json_tx = claim_tx.ToJson() + self.assertEqual(json_tx['vout'][0]['address'], self.wallet_1_addr) + nodemgr.reset_for_test() def test_wallet_claim_3(self): self.OpenWallet1() @@ -362,47 +382,67 @@ def test_wallet_claim_3(self): self.assertIn("Not correct Address, wrong length", mock_print.getvalue()) # test with --to-addr + nodemgr = NodeManager() + nodemgr.nodes = [NeoNode(object, object)] + with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=[WalletFixtureTestCase.wallet_1_pass()]): - args = ['claim', '--to-addr=' + self.watch_addr_str] - claim_tx, relayed = 
CommandWallet().execute(args) - self.assertIsInstance(claim_tx, ClaimTransaction) - self.assertTrue(relayed) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + args = ['claim', '--to-addr=' + self.watch_addr_str] + claim_tx, relayed = CommandWallet().execute(args) + self.assertIsInstance(claim_tx, ClaimTransaction) + self.assertTrue(relayed) - json_tx = claim_tx.ToJson() - self.assertEqual(json_tx['vout'][0]['address'], self.watch_addr_str) # note how the --to-addr supercedes the default change address + json_tx = claim_tx.ToJson() + self.assertEqual(json_tx['vout'][0]['address'], self.watch_addr_str) # note how the --to-addr supercedes the default change address + nodemgr.reset_for_test() def test_wallet_claim_4(self): self.OpenWallet2() # test with --from-addr and --to-addr + nodemgr = NodeManager() + nodemgr.nodes = [NeoNode(object, object)] + with patch('neo.Prompt.Commands.Wallet.prompt', side_effect=[WalletFixtureTestCase.wallet_2_pass()]): - args = ['claim', '--from-addr=' + self.wallet_1_addr, '--to-addr=' + self.wallet_2_addr] - claim_tx, relayed = CommandWallet().execute(args) - self.assertIsInstance(claim_tx, ClaimTransaction) - self.assertTrue(relayed) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + args = ['claim', '--from-addr=' + self.wallet_1_addr, '--to-addr=' + self.wallet_2_addr] + claim_tx, relayed = CommandWallet().execute(args) + self.assertIsInstance(claim_tx, ClaimTransaction) + self.assertTrue(relayed) - json_tx = claim_tx.ToJson() - self.assertEqual(json_tx['vout'][0]['address'], self.wallet_2_addr) # note how the --to-addr also supercedes the from address if both are specified + json_tx = claim_tx.ToJson() + self.assertEqual(json_tx['vout'][0]['address'], + self.wallet_2_addr) # note how the --to-addr also supercedes the from address if both are specified + nodemgr.reset_for_test() def test_wallet_rebuild(self): - with patch('neo.Prompt.PromptData.PromptData.Prompt'): - # test wallet rebuild with no wallet open - args = ['rebuild'] - res = CommandWallet().execute(args) - self.assertFalse(res) + with patch('neo.Wallets.Wallet.Wallet.sync_wallet', new_callable=self.new_async_mock) as mocked_sync_wallet: + loop = asyncio.get_event_loop() + + async def run_test(): + # test wallet rebuild with no wallet open + args = ['rebuild'] + res = CommandWallet().execute(args) + self.assertFalse(res) + + self.OpenWallet1() + + # test wallet rebuild with no argument + args = ['rebuild'] + task = CommandWallet().execute(args) - self.OpenWallet1() - PromptData.Wallet._current_height = 12345 + # "rebuild" creates a task to start syncing + # we have to wait for it to have started before we can assert the call status + await asyncio.gather(task) + mocked_sync_wallet.assert_called_with(0, rebuild=True) - # test wallet rebuild with no argument - args = ['rebuild'] - CommandWallet().execute(args) - self.assertEqual(PromptData.Wallet._current_height, 0) + # test wallet rebuild with start block + args = ['rebuild', '42'] + task = CommandWallet().execute(args) + await asyncio.gather(task) + mocked_sync_wallet.assert_called_with(42, rebuild=True) - # test wallet rebuild with start block - args = ['rebuild', '42'] - CommandWallet().execute(args) - self.assertEqual(PromptData.Wallet._current_height, 42) + loop.run_until_complete(run_test()) def test_wallet_unspent(self): # test wallet unspent with no wallet open diff --git a/neo/Prompt/PromptData.py b/neo/Prompt/PromptData.py index 2bb5a6cdc..b123e7ce3 100644 --- 
a/neo/Prompt/PromptData.py +++ b/neo/Prompt/PromptData.py @@ -1,3 +1,6 @@ +from neo.Core.Blockchain import Blockchain + + class PromptData: Prompt = None Wallet = None @@ -8,7 +11,7 @@ def close_wallet(): return False path = PromptData.Wallet._path - PromptData.Prompt.stop_wallet_loop() + Blockchain.Default().PersistCompleted.on_change -= PromptData.Wallet.ProcessNewBlock PromptData.Wallet.Close() PromptData.Wallet = None print("Closed wallet %s" % path) diff --git a/neo/Prompt/Utils.py b/neo/Prompt/Utils.py index 2c8a914cd..8e30035a2 100644 --- a/neo/Prompt/Utils.py +++ b/neo/Prompt/Utils.py @@ -8,11 +8,11 @@ from neo.SmartContract.ContractParameter import ContractParameterType from neo.Core.Cryptography.ECCurve import ECDSA from decimal import Decimal -from prompt_toolkit.shortcuts import PromptSession from neo.logging import log_manager from neo.Wallets import NEP5Token from neo.Core.Cryptography.Crypto import Crypto from typing import TYPE_CHECKING +from neo.Network.common import blocking_prompt as prompt if TYPE_CHECKING: from neo.Wallets.Wallet import Wallet @@ -309,10 +309,7 @@ def string_from_fixed8(amount, decimals): def get_input_prompt(message): - from neo.bin.prompt import PromptInterface - - return PromptSession(completer=PromptInterface.prompt_completer, - history=PromptInterface.history).prompt(message) + return prompt(message) def gather_param(index, param_type, do_continue=True): diff --git a/neo/Prompt/vm_debugger.py b/neo/Prompt/vm_debugger.py index 029ffa085..c226c1282 100644 --- a/neo/Prompt/vm_debugger.py +++ b/neo/Prompt/vm_debugger.py @@ -1,4 +1,4 @@ -from prompt_toolkit import prompt +from neo.Network.common import blocking_prompt as prompt from neo.Prompt.InputParser import InputParser from neo.SmartContract.ContractParameter import ContractParameter from neo.SmartContract.ContractParameterType import ContractParameterType @@ -180,7 +180,12 @@ def start(self): value = self.engine.AltStack.Items[-1].GetArray()[idx] param = ContractParameter.ToParameter(value) print("\n") - print('%s = %s [%s]' % (command, json.dumps(param.Value.ToJson(), indent=4) if param.Type == ContractParameterType.InteropInterface else param.Value, param.Type)) + + if param.Type == ContractParameterType.InteropInterface: + cmd_value = json.dumps(param.Value.ToJson(), indent=4) + else: + cmd_value = param.Value + print(f"{command} = {cmd_value} [{param.Type}]") print("\n") except Exception as e: logger.error("Could not lookup item %s: %s " % (command, e)) diff --git a/neo/Settings.py b/neo/Settings.py index 47cee8ffa..677c88ba2 100644 --- a/neo/Settings.py +++ b/neo/Settings.py @@ -105,7 +105,8 @@ class SettingsHolder: DEBUG_STORAGE_PATH = 'Chains/debugstorage' ACCEPT_INCOMING_PEERS = False - CONNECTED_PEER_MAX = 20 + CONNECTED_PEER_MAX = 10 + CONNECTED_PEER_MIN = 4 SERVICE_ENABLED = True @@ -279,6 +280,13 @@ def set_max_peers(self, num_peers): else: raise ValueError + def set_min_peers(self, num_peers): + minpeers = int(num_peers) + if minpeers > 0: + self.CONNECTED_PEER_MIN = minpeers + else: + raise ValueError + def set_log_smart_contract_events(self, is_enabled=True): self.log_smart_contract_events = is_enabled diff --git a/neo/SmartContract/ApplicationEngine.py b/neo/SmartContract/ApplicationEngine.py index 785e43d13..0775415b8 100644 --- a/neo/SmartContract/ApplicationEngine.py +++ b/neo/SmartContract/ApplicationEngine.py @@ -22,51 +22,52 @@ from neo.VM.ExecutionEngine import ExecutionEngine from neo.VM.InteropService import Array from neo.VM.OpCode import APPCALL, TAILCALL, \ - SYSCALL, 
NOP, SHA256, SHA1, HASH160, HASH256, CHECKSIG, CHECKMULTISIG + SYSCALL, NOP, SHA256, SHA1, HASH160, HASH256, CHECKSIG, CHECKMULTISIG, VERIFY from neo.logging import log_manager logger = log_manager.getLogger('vm') +HASH_NEO_ASSET_CREATE = hash("Neo.Asset.Create") +HASH_ANT_ASSET_CREATE = hash("AntShares.Asset.Create") +HASH_NEO_ASSET_RENEW = hash("Neo.Asset.Renew") +HASH_ANT_ASSET_RENEW = hash("AntShares.Asset.Renew") +HASH_NEO_CONTRACT_CREATE = hash("Neo.Contract.Create") +HASH_NEO_CONTRACT_MIGRATE = hash("Neo.Contract.Migrate") +HASH_ANT_CONTRACT_CREATE = hash("AntShares.Contract.Create") +HASH_ANT_CONTRACT_MIGRATE = hash("AntShares.Contract.Migrate") +HASH_SYSTEM_STORAGE_PUT = hash("System.Storage.Put") +HASH_SYSTEM_STORAGE_PUTEX = hash("System.Storage.PutEx") +HASH_NEO_STORAGE_PUT = hash("Neo.Storage.Put") +HASH_ANT_STORAGE_PUT = hash("AntShares.Storage.Put") + class ApplicationEngine(ExecutionEngine): ratio = 100000 gas_free = 10 * 100000000 - gas_amount = 0 - gas_consumed = 0 - testMode = False - - Trigger = None - - invocation_args = None - max_free_ops = 500000 def GasConsumed(self): return Fixed8(self.gas_consumed) - def __init__(self, trigger_type, container, table, service, gas, testMode=False, exit_on_error=False): + def __init__(self, trigger_type, container, table, service, gas, testMode=False, exit_on_error=True): super(ApplicationEngine, self).__init__(container=container, crypto=Crypto.Default(), table=table, service=service, exit_on_error=exit_on_error) + self.service = service self.Trigger = trigger_type self.gas_amount = self.gas_free + gas.value self.testMode = testMode self._is_stackitem_count_strict = True + self.debugger = None + self.gas_consumed = 0 + self.invocation_args = None - def CheckDynamicInvoke(self, opcode): + def CheckDynamicInvoke(self): cx = self.CurrentContext + opcode = cx.CurrentInstruction.OpCode if opcode in [OpCode.APPCALL, OpCode.TAILCALL]: - opreader = cx.OpReader - # read the current position of the stream - start_pos = opreader.stream.tell() - - # normal app calls are stored in the op reader - # we read ahead past the next instruction 1 the next 20 bytes - script_hash = opreader.ReadBytes(21)[1:] - - # then reset the position - opreader.stream.seek(start_pos) + script_hash = cx.CurrentInstruction.Operand for b in script_hash: # if any of the bytes are greater than 0, this is a normal app call @@ -88,8 +89,8 @@ def CheckDynamicInvoke(self, opcode): else: return True - def PreStepInto(self, opcode): - if self.CurrentContext.InstructionPointer >= len(self.CurrentContext.Script): + def PreExecuteInstruction(self): + if self.CurrentContext.InstructionPointer >= self.CurrentContext.Script.Length: return True self.gas_consumed = self.gas_consumed + (self.GetPrice() * self.ratio) if not self.testMode and self.gas_consumed > self.gas_amount: @@ -98,45 +99,15 @@ def PreStepInto(self, opcode): logger.debug("Too many free operations processed") return False try: - if not self.CheckDynamicInvoke(opcode): + if not self.CheckDynamicInvoke(): return False except Exception: pass return True - # @profile_it - def Execute(self): - try: - if settings.log_vm_instructions: - self.log_file = open(self.log_file_name, 'w') - self.write_log(str(datetime.datetime.now())) - - while True: - if self.CurrentContext.InstructionPointer >= len(self.CurrentContext.Script): - nextOpcode = OpCode.RET - else: - nextOpcode = self.CurrentContext.NextInstruction - - if not self.PreStepInto(nextOpcode): - # TODO: check with NEO is this should now be changed to not use |= - 
self._VMState |= VMState.FAULT - return False - self.StepInto() - if self._VMState & VMState.HALT > 0 or self._VMState & VMState.FAULT > 0: - break - except Exception: - self._VMState |= VMState.FAULT - return False - finally: - if self.log_file: - self.log_file.close() - - return not self._VMState & VMState.FAULT > 0 - def GetPrice(self): - opcode = self.CurrentContext.NextInstruction - + opcode = self.CurrentContext.CurrentInstruction.OpCode if opcode <= NOP: return 0 @@ -148,7 +119,7 @@ def GetPrice(self): return 10 elif opcode in [HASH160, HASH256]: return 20 - elif opcode == CHECKSIG: + elif opcode in [CHECKSIG, VERIFY]: return 100 elif opcode == CHECKMULTISIG: if self.CurrentContext.EvaluationStack.Count == 0: @@ -169,75 +140,21 @@ def GetPrice(self): return 1 def GetPriceForSysCall(self): + instruction = self.CurrentContext.CurrentInstruction + api_hash = instruction.TokenU32 if len(instruction.Operand) == 4 else hash(instruction.TokenString) - if self.CurrentContext.InstructionPointer >= len(self.CurrentContext.Script) - 3: - return 1 - - length = self.CurrentContext.Script[self.CurrentContext.InstructionPointer + 1] - - if self.CurrentContext.InstructionPointer > len(self.CurrentContext.Script) - length - 2: - return 1 - - strbytes = self.CurrentContext.Script[self.CurrentContext.InstructionPointer + 2:length + self.CurrentContext.InstructionPointer + 2] - - api_name = strbytes.decode('utf-8') + price = self.service.GetPrice(api_hash) - api = api_name.replace('Antshares.', 'Neo.') - api = api.replace('System.', 'Neo.') + if price > 0: + return price - if api == "Neo.Runtime.CheckWitness": - return 200 - - elif api == "Neo.Blockchain.GetHeader": - return 100 - - elif api == "Neo.Blockchain.GetBlock": - return 200 - - elif api == "Neo.Blockchain.GetTransaction": - return 100 - - elif api == "Neo.Blockchain.GetTransactionHeight": - return 100 - - elif api == "Neo.Blockchain.GetAccount": - return 100 - - elif api == "Neo.Blockchain.GetValidators": - return 200 - - elif api == "Neo.Blockchain.GetAsset": - return 100 - - elif api == "Neo.Blockchain.GetContract": - return 100 - - elif api == "Neo.Transaction.GetReferences": - return 200 - - elif api == "Neo.Transaction.GetWitnesses": - return 200 - - elif api == "Neo.Transaction.GetUnspentCoins": - return 200 - - elif api in ["Neo.Witness.GetInvocationScript", "Neo.Witness.GetVerificationScript"]: - return 100 - - elif api == "Neo.Account.SetVotes": - return 1000 - - elif api == "Neo.Validator.Register": - return int(1000 * 100000000 / self.ratio) - - elif api == "Neo.Asset.Create": + if api_hash == HASH_NEO_ASSET_CREATE or api_hash == HASH_ANT_ASSET_CREATE: return int(5000 * 100000000 / self.ratio) - elif api == "Neo.Asset.Renew": + if api_hash == HASH_NEO_ASSET_RENEW or api_hash == HASH_ANT_ASSET_RENEW: return int(self.CurrentContext.EvaluationStack.Peek(1).GetBigInteger() * 5000 * 100000000 / self.ratio) - elif api == "Neo.Contract.Create" or api == "Neo.Contract.Migrate": - + if api_hash == HASH_NEO_CONTRACT_CREATE or api_hash == HASH_NEO_CONTRACT_MIGRATE or api_hash == HASH_ANT_CONTRACT_CREATE or api_hash == HASH_ANT_CONTRACT_MIGRATE: fee = int(100 * 100000000 / self.ratio) # 100 gas for contract with no storage no dynamic invoke contract_properties = self.CurrentContext.EvaluationStack.Peek(3).GetBigInteger() @@ -250,21 +167,15 @@ def GetPriceForSysCall(self): return fee - elif api == "Neo.Storage.Get": - return 100 - - elif api == "Neo.Storage.Put": + if api_hash == HASH_SYSTEM_STORAGE_PUT or api_hash == HASH_SYSTEM_STORAGE_PUTEX 
or api_hash == HASH_NEO_STORAGE_PUT or api_hash == HASH_ANT_STORAGE_PUT: l1 = len(self.CurrentContext.EvaluationStack.Peek(1).GetByteArray()) l2 = len(self.CurrentContext.EvaluationStack.Peek(2).GetByteArray()) return (int((l1 + l2 - 1) / 1024) + 1) * 1000 - elif api == "Neo.Storage.Delete": - return 100 - return 1 @staticmethod - def Run(script, container=None, exit_on_error=False, gas=Fixed8.Zero(), test_mode=True): + def Run(script, container=None, exit_on_error=True, gas=Fixed8.Zero(), test_mode=True): """ Runs a script in a test invoke environment diff --git a/neo/SmartContract/Contract.py b/neo/SmartContract/Contract.py index 08b247c87..300b1de92 100644 --- a/neo/SmartContract/Contract.py +++ b/neo/SmartContract/Contract.py @@ -24,10 +24,6 @@ class ContractType: class Contract(SerializableMixin, VerificationCode): """docstring for Contract""" - PublicKeyHash = None - - _address = None - @property def Address(self): if self._address is None: diff --git a/neo/SmartContract/ContractParameter.py b/neo/SmartContract/ContractParameter.py index a6a5a0676..1d9fd9788 100644 --- a/neo/SmartContract/ContractParameter.py +++ b/neo/SmartContract/ContractParameter.py @@ -10,9 +10,6 @@ class ContractParameter: """Contract Parameter used for parsing parameters sent to and from smart contract invocations""" - Type = None - Value = None - def __init__(self, type, value): """ diff --git a/neo/SmartContract/ContractParameterContext.py b/neo/SmartContract/ContractParameterContext.py index e27960aa5..88e0ab06b 100755 --- a/neo/SmartContract/ContractParameterContext.py +++ b/neo/SmartContract/ContractParameterContext.py @@ -4,7 +4,7 @@ from neo.SmartContract.Contract import Contract, ContractType from neo.SmartContract.ContractParameterType import ContractParameterType, ToName from neo.VM.ScriptBuilder import ScriptBuilder -from neo.IO.MemoryStream import MemoryStream +from neo.IO.MemoryStream import StreamManager, MemoryStream from neo.Core.IO.BinaryReader import BinaryReader from neo.Core.IO.BinaryWriter import BinaryWriter from neo.VM import OpCode @@ -16,9 +16,6 @@ class ContractParamater: - Type = None - Value = None - def __init__(self, type): if isinstance(type, ContractParameterType): self.Type = type @@ -26,6 +23,7 @@ def __init__(self, type): self.Type = ContractParameterType(type) else: raise Exception("Invalid Contract Parameter Type %s. 
Must be ContractParameterType or int" % type) + self.Value = None def ToJson(self): jsn = {} @@ -34,13 +32,9 @@ def ToJson(self): class ContextItem: - Script = None - ContractParameters = None - Signatures = None - - IsCustomContract = False def __init__(self, contract): + self.Signatures = None self.Script = contract.Script self.ContractParameters = [] for b in bytearray(contract.ParameterList): @@ -76,14 +70,6 @@ def ToJson(self): class ContractParametersContext: - Verifiable = None - - ScriptHashes = None - - ContextItems = None - - IsMultiSig = None - def __init__(self, verifiable, isMultiSig=False): self.Verifiable = verifiable @@ -162,7 +148,7 @@ def AddSignature(self, contract, pubkey, signature): ecdsa = ECDSA.secp256r1() points = [] temp = binascii.unhexlify(contract.Script) - ms = MemoryStream(binascii.unhexlify(contract.Script)) + ms = StreamManager.GetStream(binascii.unhexlify(contract.Script)) reader = BinaryReader(ms) numr = reader.ReadUInt8() try: @@ -172,7 +158,7 @@ def AddSignature(self, contract, pubkey, signature): except ValueError: return False finally: - ms.close() + StreamManager.ReleaseStream(ms) if pubkey not in points: return False @@ -290,9 +276,10 @@ def FromJson(jsn, isMultiSig=True): parsed = json.loads(jsn) if parsed['type'] == 'Neo.Core.ContractTransaction': verifiable = ContractTransaction() - ms = MemoryStream(binascii.unhexlify(parsed['hex'])) + ms = StreamManager.GetStream(binascii.unhexlify(parsed['hex'])) r = BinaryReader(ms) verifiable.DeserializeUnsigned(r) + StreamManager.ReleaseStream(ms) context = ContractParametersContext(verifiable, isMultiSig=isMultiSig) for key, value in parsed['items'].items(): if "0x" in key: diff --git a/neo/SmartContract/Iterable/ArrayWrapper.py b/neo/SmartContract/Iterable/ArrayWrapper.py new file mode 100644 index 000000000..4098c20e3 --- /dev/null +++ b/neo/SmartContract/Iterable/ArrayWrapper.py @@ -0,0 +1,27 @@ +from neo.SmartContract.Iterable import Iterator + + +class ArrayWrapper(Iterator): + def __init__(self, array): + self.array = array + self.index = -1 + + def Dispose(self): + pass + + def Key(self): + if self.index < 0: + raise ValueError + return self.index + + def Next(self) -> bool: + next = self.index + 1 + if next >= len(self.array): + return False + self.index = next + return True + + def Value(self): + if self.index < 0: + raise ValueError + return self.array[self.index] diff --git a/neo/SmartContract/Iterable/ConcatenatedIterator.py b/neo/SmartContract/Iterable/ConcatenatedIterator.py new file mode 100644 index 000000000..5c5365d23 --- /dev/null +++ b/neo/SmartContract/Iterable/ConcatenatedIterator.py @@ -0,0 +1,34 @@ +from neo.SmartContract.Iterable import Iterator +from neo.SmartContract.Iterable.ArrayWrapper import ArrayWrapper +from neo.VM.InteropService import StackItem + + +class ConcatenatedIterator(Iterator): + def __init__(self, first, second): + if first == second: + new_list = [] + while first.Next(): + new_list.append(first.Value()) + + first = ArrayWrapper(new_list) + second = ArrayWrapper(new_list) + + self.first = first + self.current = self.first + self.second = second + + def Key(self) -> StackItem: + return self.current.Key() + + def Value(self) -> StackItem: + return self.current.Value() + + def Next(self) -> bool: + if self.current.Next(): + return True + + self.current = self.second + return self.current.Next() + + def Dispose(self): + pass diff --git a/neo/SmartContract/Iterable/test_interop_iterable.py b/neo/SmartContract/Iterable/test_interop_iterable.py index 
193b55cb7..ca7a76fce 100644 --- a/neo/SmartContract/Iterable/test_interop_iterable.py +++ b/neo/SmartContract/Iterable/test_interop_iterable.py @@ -1,22 +1,21 @@ from unittest import TestCase -from neo.VM.InteropService import Struct, StackItem, Array, Boolean, Map +from neo.VM.InteropService import StackItem, Array, Map from neo.VM.ExecutionEngine import ExecutionEngine from neo.VM.ExecutionEngine import ExecutionContext -from neo.SmartContract.StateReader import StateReader -from neo.SmartContract.Iterable import Iterator, KeysWrapper, ValuesWrapper +from neo.VM.Script import Script +from neo.SmartContract.Iterable import KeysWrapper, ValuesWrapper from neo.SmartContract.Iterable.Wrapper import ArrayWrapper, MapWrapper from neo.SmartContract.Iterable.ConcatenatedEnumerator import ConcatenatedEnumerator +from neo.SmartContract.StateMachine import StateMachine class InteropSerializeDeserializeTestCase(TestCase): - engine = None - econtext = None - state_reader = None - def setUp(self): self.engine = ExecutionEngine() - self.econtext = ExecutionContext(engine=self.engine) - self.state_reader = StateReader() + self.econtext = ExecutionContext(Script(self.engine.Crypto, b''), 0) + self.engine.InvocationStack.PushT(self.econtext) + + self.service = StateMachine(None, None, None, None, None, None) def test_iter_array(self): my_array = Array([StackItem.New(12), @@ -26,7 +25,7 @@ def test_iter_array(self): ]) self.econtext.EvaluationStack.PushT(my_array) self.engine.InvocationStack.PushT(self.econtext) - self.state_reader.Enumerator_Create(self.engine) + self.service.Enumerator_Create(self.engine) iterable = self.econtext.EvaluationStack.Peek(0).GetInterface() @@ -54,7 +53,7 @@ def test_iter_map(self): self.econtext.EvaluationStack.PushT(my_map) self.engine.InvocationStack.PushT(self.econtext) - self.state_reader.Iterator_Create(self.engine) + self.service.Iterator_Create(self.engine) iterable = self.econtext.EvaluationStack.Peek(0).GetInterface() @@ -80,9 +79,9 @@ def test_iter_array_keys(self): ]) self.econtext.EvaluationStack.PushT(my_array) self.engine.InvocationStack.PushT(self.econtext) - self.state_reader.Enumerator_Create(self.engine) + self.service.Enumerator_Create(self.engine) - create_iterkeys = self.state_reader.Iterator_Keys(self.engine) + create_iterkeys = self.service.Iterator_Keys(self.engine) self.assertEqual(create_iterkeys, True) @@ -104,9 +103,9 @@ def test_iter_array_values(self): ]) self.econtext.EvaluationStack.PushT(my_array) self.engine.InvocationStack.PushT(self.econtext) - self.state_reader.Enumerator_Create(self.engine) + self.service.Enumerator_Create(self.engine) - create_itervalues = self.state_reader.Iterator_Values(self.engine) + create_itervalues = self.service.Iterator_Values(self.engine) self.assertEqual(create_itervalues, True) @@ -131,13 +130,13 @@ def test_iter_concat(self): self.econtext.EvaluationStack.PushT(my_array2) self.engine.InvocationStack.PushT(self.econtext) - self.state_reader.Enumerator_Create(self.engine) + self.service.Enumerator_Create(self.engine) self.econtext.EvaluationStack.PushT(my_array) - self.state_reader.Enumerator_Create(self.engine) + self.service.Enumerator_Create(self.engine) - result = self.state_reader.Enumerator_Concat(self.engine) + result = self.service.Enumerator_Concat(self.engine) self.assertEqual(result, True) @@ -160,7 +159,7 @@ def test_iter_array_bad(self): self.econtext.EvaluationStack.PushT(my_item) self.engine.InvocationStack.PushT(self.econtext) - result = self.state_reader.Enumerator_Create(self.engine) + result 
= self.service.Enumerator_Create(self.engine) self.assertEqual(result, False) self.assertEqual(self.econtext.EvaluationStack.Count, 0) @@ -169,7 +168,7 @@ def test_iter_map_bad(self): my_item = StackItem.New(12) self.econtext.EvaluationStack.PushT(my_item) self.engine.InvocationStack.PushT(self.econtext) - result = self.state_reader.Iterator_Create(self.engine) + result = self.service.Iterator_Create(self.engine) self.assertEqual(result, False) self.assertEqual(self.econtext.EvaluationStack.Count, 0) @@ -179,7 +178,7 @@ def test_iter_array_key_bad(self): self.econtext.EvaluationStack.PushT(my_item) self.engine.InvocationStack.PushT(self.econtext) - result = self.state_reader.Iterator_Key(self.engine) + result = self.service.Iterator_Key(self.engine) self.assertEqual(result, False) self.assertEqual(self.econtext.EvaluationStack.Count, 0) @@ -189,7 +188,7 @@ def test_iter_array_values_bad(self): self.econtext.EvaluationStack.PushT(my_item) self.engine.InvocationStack.PushT(self.econtext) - result = self.state_reader.Iterator_Values(self.engine) + result = self.service.Iterator_Values(self.engine) self.assertEqual(result, False) self.assertEqual(self.econtext.EvaluationStack.Count, 0) @@ -199,7 +198,7 @@ def test_iter_array_keys_bad(self): my_item = StackItem.New(12) self.econtext.EvaluationStack.PushT(my_item) self.engine.InvocationStack.PushT(self.econtext) - result = self.state_reader.Iterator_Keys(self.engine) + result = self.service.Iterator_Keys(self.engine) self.assertEqual(result, False) self.assertEqual(self.econtext.EvaluationStack.Count, 0) diff --git a/neo/SmartContract/LogEventArgs.py b/neo/SmartContract/LogEventArgs.py index b40a23bb0..87ddf01b7 100644 --- a/neo/SmartContract/LogEventArgs.py +++ b/neo/SmartContract/LogEventArgs.py @@ -1,13 +1,5 @@ - - class LogEventArgs: - - ScriptContainer = None - ScriptHash = None - Message = None - def __init__(self, container, script_hash, message): - self.ScriptContainer = container self.ScriptHash = script_hash self.Message = message diff --git a/neo/SmartContract/NotifyEventArgs.py b/neo/SmartContract/NotifyEventArgs.py index 455d62a05..899faeae9 100644 --- a/neo/SmartContract/NotifyEventArgs.py +++ b/neo/SmartContract/NotifyEventArgs.py @@ -1,13 +1,5 @@ - - class NotifyEventArgs: - - ScriptContainer = None - ScriptHash = None - State = None - def __init__(self, container, script_hash, state): - self.ScriptContainer = container self.ScriptHash = script_hash self.State = state diff --git a/neo/SmartContract/SmartContractEvent.py b/neo/SmartContract/SmartContractEvent.py index b828d1f8b..a1252bec3 100644 --- a/neo/SmartContract/SmartContractEvent.py +++ b/neo/SmartContract/SmartContractEvent.py @@ -52,17 +52,6 @@ class SmartContractEvent(SerializableMixin): CONTRACT_MIGRATED = "SmartContract.Contract.Migrate" CONTRACT_DESTROY = "SmartContract.Contract.Destroy" - event_type = None - event_payload = None # type:ContractParameter - contract_hash = None - block_number = None - tx_hash = None - execution_success = None - test_mode = None - - contract = None - token = None - def __init__(self, event_type, event_payload, contract_hash, block_number, tx_hash, execution_success=False, test_mode=False): if event_payload and not isinstance(event_payload, ContractParameter): @@ -76,6 +65,7 @@ def __init__(self, event_type, event_payload, contract_hash, block_number, tx_ha self.execution_success = execution_success self.test_mode = test_mode self.token = None + self.contract = None if not self.event_payload: self.event_payload = 
ContractParameter(ContractParameterType.Array, value=[]) diff --git a/neo/SmartContract/StateMachine.py b/neo/SmartContract/StateMachine.py index 160b81d39..8b1e263fe 100644 --- a/neo/SmartContract/StateMachine.py +++ b/neo/SmartContract/StateMachine.py @@ -11,10 +11,16 @@ from neo.Core.Cryptography.ECCurve import ECDSA from neo.Core.UInt160 import UInt160 from neo.Core.UInt256 import UInt256 +from neo.Core.State.AccountState import AccountState from neo.Core.Fixed8 import Fixed8 from neo.VM.InteropService import StackItem from neo.VM.ExecutionEngine import ExecutionEngine from neo.SmartContract.StorageContext import StorageContext +from neo.VM.InteropService import StackItem, ByteArray, Array, Map +from neo.SmartContract.Iterable.Wrapper import ArrayWrapper, MapWrapper +from neo.SmartContract.Iterable import KeysWrapper, ValuesWrapper +from neo.SmartContract.Iterable.ConcatenatedEnumerator import ConcatenatedEnumerator +from neo.SmartContract.Iterable.ConcatenatedIterator import ConcatenatedIterator from neo.SmartContract.StateReader import StateReader from neo.SmartContract.ContractParameter import ContractParameter, ContractParameterType from neo.EventHub import SmartContractEvent @@ -24,11 +30,6 @@ class StateMachine(StateReader): - _validators = None - _wb = None - - _contracts_created = {} - def __init__(self, accounts, validators, assets, contracts, storages, wb): super(StateMachine, self).__init__() @@ -39,41 +40,173 @@ def __init__(self, accounts, validators, assets, contracts, storages, wb): self._contracts = contracts self._storages = storages self._wb = wb + self._contracts_created = {} + + # checks for testing purposes + if accounts is not None: + self._accounts.MarkForReset() + if validators is not None: + self._validators.MarkForReset() + if assets is not None: + self._assets.MarkForReset() + if contracts is not None: + self._contracts.MarkForReset() + if storages is not None: + self._storages.MarkForReset() + + self.RegisterWithPrice("Neo.Runtime.GetTrigger", self.Runtime_GetTrigger, 1) + self.RegisterWithPrice("Neo.Runtime.CheckWitness", self.Runtime_CheckWitness, 200) + self.RegisterWithPrice("Neo.Runtime.Notify", self.Runtime_Notify, 1) + self.RegisterWithPrice("Neo.Runtime.Log", self.Runtime_Log, 1) + self.RegisterWithPrice("Neo.Runtime.GetTime", self.Runtime_GetCurrentTime, 1) + self.RegisterWithPrice("Neo.Runtime.Serialize", self.Runtime_Serialize, 1) + self.RegisterWithPrice("Neo.Runtime.Deserialize", self.Runtime_Deserialize, 1) + + self.RegisterWithPrice("Neo.Blockchain.GetHeight", self.Blockchain_GetHeight, 1) + self.RegisterWithPrice("Neo.Blockchain.GetHeader", self.Blockchain_GetHeader, 100) + self.RegisterWithPrice("Neo.Blockchain.GetBlock", self.Blockchain_GetBlock, 200) + self.RegisterWithPrice("Neo.Blockchain.GetTransaction", self.Blockchain_GetTransaction, 100) + self.RegisterWithPrice("Neo.Blockchain.GetTransactionHeight", self.Blockchain_GetTransactionHeight, 100) + self.RegisterWithPrice("Neo.Blockchain.GetAccount", self.Blockchain_GetAccount, 100) + self.RegisterWithPrice("Neo.Blockchain.GetValidators", self.Blockchain_GetValidators, 100) + self.RegisterWithPrice("Neo.Blockchain.GetAsset", self.Blockchain_GetAsset, 100) + self.RegisterWithPrice("Neo.Blockchain.GetContract", self.Blockchain_GetContract, 100) + + self.RegisterWithPrice("Neo.Header.GetHash", self.Header_GetHash, 1) + self.RegisterWithPrice("Neo.Header.GetVersion", self.Header_GetVersion, 1) + self.RegisterWithPrice("Neo.Header.GetPrevHash", self.Header_GetPrevHash, 1) + 
self.RegisterWithPrice("Neo.Header.GetMerkleRoot", self.Header_GetMerkleRoot, 1) + self.RegisterWithPrice("Neo.Header.GetTimestamp", self.Header_GetTimestamp, 1) + self.RegisterWithPrice("Neo.Header.GetIndex", self.Header_GetIndex, 1) + self.RegisterWithPrice("Neo.Header.GetConsensusData", self.Header_GetConsensusData, 1) + self.RegisterWithPrice("Neo.Header.GetNextConsensus", self.Header_GetNextConsensus, 1) + + self.RegisterWithPrice("Neo.Block.GetTransactionCount", self.Block_GetTransactionCount, 1) + self.RegisterWithPrice("Neo.Block.GetTransactions", self.Block_GetTransactions, 1) + self.RegisterWithPrice("Neo.Block.GetTransaction", self.Block_GetTransaction, 1) + + self.RegisterWithPrice("Neo.Transaction.GetHash", self.Transaction_GetHash, 1) + self.RegisterWithPrice("Neo.Transaction.GetType", self.Transaction_GetType, 1) + self.RegisterWithPrice("Neo.Transaction.GetAttributes", self.Transaction_GetAttributes, 1) + self.RegisterWithPrice("Neo.Transaction.GetInputs", self.Transaction_GetInputs, 1) + self.RegisterWithPrice("Neo.Transaction.GetOutputs", self.Transaction_GetOutputs, 1) + self.RegisterWithPrice("Neo.Transaction.GetReferences", self.Transaction_GetReferences, 200) + self.RegisterWithPrice("Neo.Transaction.GetUnspentCoins", self.Transaction_GetUnspentCoins, 200) + self.RegisterWithPrice("Neo.Transaction.GetWitnesses", self.Transaction_GetWitnesses, 200) + + self.RegisterWithPrice("Neo.InvocationTransaction.GetScript", self.InvocationTransaction_GetScript, 1) + self.RegisterWithPrice("Neo.Witness.GetVerificationScript", self.Witness_GetVerificationScript, 100) + self.RegisterWithPrice("Neo.Attribute.GetUsage", self.Attribute_GetUsage, 1) + self.RegisterWithPrice("Neo.Attribute.GetData", self.Attribute_GetData, 1) + + self.RegisterWithPrice("Neo.Input.GetHash", self.Input_GetHash, 1) + self.RegisterWithPrice("Neo.Input.GetIndex", self.Input_GetIndex, 1) + self.RegisterWithPrice("Neo.Output.GetAssetId", self.Output_GetAssetId, 1) + self.RegisterWithPrice("Neo.Output.GetValue", self.Output_GetValue, 1) + self.RegisterWithPrice("Neo.Output.GetScriptHash", self.Output_GetScriptHash, 1) + + self.RegisterWithPrice("Neo.Account.GetScriptHash", self.Account_GetScriptHash, 1) + self.RegisterWithPrice("Neo.Account.GetVotes", self.Account_GetVotes, 1) + self.RegisterWithPrice("Neo.Account.GetBalance", self.Account_GetBalance, 1) + self.RegisterWithPrice("Neo.Account.IsStandard", self.Account_IsStandard, 100) - self._accounts.MarkForReset() - self._validators.MarkForReset() - self._assets.MarkForReset() - self._contracts.MarkForReset() - self._storages.MarkForReset() - - # Standard Library - self.Register("System.Contract.GetStorageContext", self.Contract_GetStorageContext) - self.Register("System.Contract.Destroy", self.Contract_Destroy) - self.Register("System.Storage.Put", self.Storage_Put) - self.Register("System.Storage.Delete", self.Storage_Delete) - - # Neo specific self.Register("Neo.Asset.Create", self.Asset_Create) self.Register("Neo.Asset.Renew", self.Asset_Renew) - self.Register("Neo.Contract.Migrate", self.Contract_Migrate) - self.Register("Neo.Contract.Create", self.Contract_Create) + self.RegisterWithPrice("Neo.Asset.GetAssetId", self.Asset_GetAssetId, 1) + self.RegisterWithPrice("Neo.Asset.GetAssetType", self.Asset_GetAssetType, 1) + self.RegisterWithPrice("Neo.Asset.GetAmount", self.Asset_GetAmount, 1) + self.RegisterWithPrice("Neo.Asset.GetAvailable", self.Asset_GetAvailable, 1) + self.RegisterWithPrice("Neo.Asset.GetPrecision", self.Asset_GetPrecision, 1) + 
self.RegisterWithPrice("Neo.Asset.GetOwner", self.Asset_GetOwner, 1) + self.RegisterWithPrice("Neo.Asset.GetAdmin", self.Asset_GetAdmin, 1) + self.RegisterWithPrice("Neo.Asset.GetIssuer", self.Asset_GetIssuer, 1) - # Old - self.Register("Neo.Contract.GetStorageContext", self.Contract_GetStorageContext) - self.Register("Neo.Contract.Destroy", self.Contract_Destroy) + self.Register("Neo.Contract.Create", self.Contract_Create) + self.Register("Neo.Contract.Migrate", self.Contract_Migrate) + self.RegisterWithPrice("Neo.Contract.Destroy", self.Contract_Destroy, 1) + self.RegisterWithPrice("Neo.Contract.GetScript", self.Contract_GetScript, 1) + self.RegisterWithPrice("Neo.Contract.IsPayable", self.Contract_IsPayable, 1) + self.RegisterWithPrice("Neo.Contract.GetStorageContext", self.Contract_GetStorageContext, 1) + + self.RegisterWithPrice("Neo.Storage.GetContext", self.Storage_GetContext, 1) + self.RegisterWithPrice("Neo.Storage.GetReadOnlyContext", self.Storage_GetReadOnlyContext, 1) + self.RegisterWithPrice("Neo.Storage.Get", self.Storage_Get, 100) self.Register("Neo.Storage.Put", self.Storage_Put) - self.Register("Neo.Storage.Delete", self.Storage_Delete) - - # Very old - self.Register("AntShares.Account.SetVotes", self.Deprecated_Method) + self.RegisterWithPrice("Neo.Storage.Delete", self.Storage_Delete, 100) + self.RegisterWithPrice("Neo.Storage.Find", self.Storage_Find, 1) + self.RegisterWithPrice("Neo.StorageContext.AsReadOnly", self.StorageContext_AsReadOnly, 1) + + self.RegisterWithPrice("Neo.Enumerator.Create", self.Enumerator_Create, 1) + self.RegisterWithPrice("Neo.Enumerator.Next", self.Enumerator_Next, 1) + self.RegisterWithPrice("Neo.Enumerator.Value", self.Enumerator_Value, 1) + self.RegisterWithPrice("Neo.Enumerator.Concat", self.Enumerator_Concat, 1) + self.RegisterWithPrice("Neo.Iterator.Create", self.Iterator_Create, 1) + self.RegisterWithPrice("Neo.Iterator.Key", self.Iterator_Key, 1) + self.RegisterWithPrice("Neo.Iterator.Keys", self.Iterator_Keys, 1) + self.RegisterWithPrice("Neo.Iterator.Values", self.Iterator_Values, 1) + self.RegisterWithPrice("Neo.Iterator.Concat", self.Iterator_Concat, 1) + + # Aliases + self.RegisterWithPrice("Neo.Iterator.Next", self.Enumerator_Next, 1) + self.RegisterWithPrice("Neo.Iterator.Value", self.Enumerator_Value, 1) + + # Old APIs + self.RegisterWithPrice("AntShares.Runtime.CheckWitness", self.Runtime_CheckWitness, 200) + self.RegisterWithPrice("AntShares.Runtime.Notify", self.Runtime_Notify, 1) + self.RegisterWithPrice("AntShares.Runtime.Log", self.Runtime_Log, 1) + self.RegisterWithPrice("AntShares.Blockchain.GetHeight", self.Blockchain_GetHeight, 1) + self.RegisterWithPrice("AntShares.Blockchain.GetHeader", self.Blockchain_GetHeader, 100) + self.RegisterWithPrice("AntShares.Blockchain.GetBlock", self.Blockchain_GetBlock, 200) + self.RegisterWithPrice("AntShares.Blockchain.GetTransaction", self.Blockchain_GetTransaction, 100) + self.RegisterWithPrice("AntShares.Blockchain.GetAccount", self.Blockchain_GetAccount, 100) + self.RegisterWithPrice("AntShares.Blockchain.GetValidators", self.Blockchain_GetValidators, 200) + self.RegisterWithPrice("AntShares.Blockchain.GetAsset", self.Blockchain_GetAsset, 100) + self.RegisterWithPrice("AntShares.Blockchain.GetContract", self.Blockchain_GetContract, 100) + self.RegisterWithPrice("AntShares.Header.GetHash", self.Header_GetHash, 1) + self.RegisterWithPrice("AntShares.Header.GetVersion", self.Header_GetVersion, 1) + self.RegisterWithPrice("AntShares.Header.GetPrevHash", self.Header_GetPrevHash, 1) + 
self.RegisterWithPrice("AntShares.Header.GetMerkleRoot", self.Header_GetMerkleRoot, 1) + self.RegisterWithPrice("AntShares.Header.GetTimestamp", self.Header_GetTimestamp, 1) + self.RegisterWithPrice("AntShares.Header.GetConsensusData", self.Header_GetConsensusData, 1) + self.RegisterWithPrice("AntShares.Header.GetNextConsensus", self.Header_GetNextConsensus, 1) + self.RegisterWithPrice("AntShares.Block.GetTransactionCount", self.Block_GetTransactionCount, 1) + self.RegisterWithPrice("AntShares.Block.GetTransactions", self.Block_GetTransactions, 1) + self.RegisterWithPrice("AntShares.Block.GetTransaction", self.Block_GetTransaction, 1) + self.RegisterWithPrice("AntShares.Transaction.GetHash", self.Transaction_GetHash, 1) + self.RegisterWithPrice("AntShares.Transaction.GetType", self.Transaction_GetType, 1) + self.RegisterWithPrice("AntShares.Transaction.GetAttributes", self.Transaction_GetAttributes, 1) + self.RegisterWithPrice("AntShares.Transaction.GetInputs", self.Transaction_GetInputs, 1) + self.RegisterWithPrice("AntShares.Transaction.GetOutpus", self.Transaction_GetOutputs, 1) + self.RegisterWithPrice("AntShares.Transaction.GetReferences", self.Transaction_GetReferences, 200) + self.RegisterWithPrice("AntShares.Attribute.GetData", self.Attribute_GetData, 1) + self.RegisterWithPrice("AntShares.Attribute.GetUsage", self.Attribute_GetUsage, 1) + self.RegisterWithPrice("AntShares.Input.GetHash", self.Input_GetHash, 1) + self.RegisterWithPrice("AntShares.Input.GetIndex", self.Input_GetIndex, 1) + self.RegisterWithPrice("AntShares.Output.GetAssetId", self.Output_GetAssetId, 1) + self.RegisterWithPrice("AntShares.Output.GetValue", self.Output_GetValue, 1) + self.RegisterWithPrice("AntShares.Output.GetScriptHash", self.Output_GetScriptHash, 1) + self.RegisterWithPrice("AntShares.Account.GetVotes", self.Account_GetVotes, 1) + self.RegisterWithPrice("AntShares.Account.GetBalance", self.Account_GetBalance, 1) + self.RegisterWithPrice("AntShares.Account.GetScriptHash", self.Account_GetScriptHash, 1) self.Register("AntShares.Asset.Create", self.Asset_Create) self.Register("AntShares.Asset.Renew", self.Asset_Renew) + self.RegisterWithPrice("AntShares.Asset.GetAssetId", self.Asset_GetAssetId, 1) + self.RegisterWithPrice("AntShares.Asset.GetAssetType", self.Asset_GetAssetType, 1) + self.RegisterWithPrice("AntShares.Asset.GetAmount", self.Asset_GetAmount, 1) + self.RegisterWithPrice("AntShares.Asset.GetAvailable", self.Asset_GetAvailable, 1) + self.RegisterWithPrice("AntShares.Asset.GetPrecision", self.Asset_GetPrecision, 1) + self.RegisterWithPrice("AntShares.Asset.GetOwner", self.Asset_GetOwner, 1) + self.RegisterWithPrice("AntShares.Asset.GetAdmin", self.Asset_GetAdmin, 1) + self.RegisterWithPrice("AntShares.Asset.GetIssuer", self.Asset_GetIssuer, 1) self.Register("AntShares.Contract.Create", self.Contract_Create) self.Register("AntShares.Contract.Migrate", self.Contract_Migrate) - self.Register("AntShares.Contract.GetStorageContext", self.Contract_GetStorageContext) - self.Register("AntShares.Contract.Destroy", self.Contract_Destroy) + self.RegisterWithPrice("AntShares.Contract.Destroy", self.Contract_Destroy, 1) + self.RegisterWithPrice("AntShares.Contract.GetScript", self.Contract_GetScript, 1) + self.RegisterWithPrice("AntShares.Contract.GetStorageContext", self.Contract_GetStorageContext, 1) + self.RegisterWithPrice("AntShares.Storage.GetContext", self.Storage_GetContext, 1) + self.RegisterWithPrice("AntShares.Storage.Get", self.Storage_Get, 100) self.Register("AntShares.Storage.Put", 
self.Storage_Put) - self.Register("AntShares.Storage.Delete", self.Storage_Delete) + self.RegisterWithPrice("Neo.Storage.Delete", self.Storage_Delete, 100) def ExecutionCompleted(self, engine, success, error=None): @@ -107,6 +240,261 @@ def TestCommit(self): def Deprecated_Method(self, engine): logger.debug("Method No Longer operational") + def Blockchain_GetAccount(self, engine: ExecutionEngine): + hash = UInt160(data=engine.CurrentContext.EvaluationStack.Pop().GetByteArray()) + address = Crypto.ToAddress(hash).encode('utf-8') + + account = self.Accounts.GetOrAdd(address, new_instance=AccountState(script_hash=hash)) + engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(account)) + return True + + def Blockchain_GetValidators(self, engine: ExecutionEngine): + validators = Blockchain.Default().GetValidators() + + items = [StackItem(validator.encode_point(compressed=True)) for validator in validators] + + engine.CurrentContext.EvaluationStack.PushT(items) + + return True + + def Blockchain_GetAsset(self, engine: ExecutionEngine): + data = engine.CurrentContext.EvaluationStack.Pop().GetByteArray() + asset = None + + if Blockchain.Default() is not None: + asset = self.Assets.TryGet(UInt256(data=data)) + if asset is None: + return False + engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(asset)) + return True + + def Header_GetVersion(self, engine: ExecutionEngine): + header = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if header is None: + return False + engine.CurrentContext.EvaluationStack.PushT(header.Version) + return True + + def Header_GetMerkleRoot(self, engine: ExecutionEngine): + header = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if header is None: + return False + engine.CurrentContext.EvaluationStack.PushT(header.MerkleRoot.ToArray()) + return True + + def Header_GetConsensusData(self, engine: ExecutionEngine): + header = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if header is None: + return False + engine.CurrentContext.EvaluationStack.PushT(header.ConsensusData) + return True + + def Header_GetNextConsensus(self, engine: ExecutionEngine): + header = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if header is None: + return False + engine.CurrentContext.EvaluationStack.PushT(header.NextConsensus.ToArray()) + return True + + def Transaction_GetType(self, engine: ExecutionEngine): + tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if tx is None: + return False + + if isinstance(tx.Type, bytes): + engine.CurrentContext.EvaluationStack.PushT(tx.Type) + else: + engine.CurrentContext.EvaluationStack.PushT(tx.Type.to_bytes(1, 'little')) + return True + + def Transaction_GetAttributes(self, engine: ExecutionEngine): + tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if tx is None: + return False + + if len(tx.Attributes) > engine.maxArraySize: + return False + + attr = [StackItem.FromInterface(attr) for attr in tx.Attributes] + engine.CurrentContext.EvaluationStack.PushT(attr) + return True + + def Transaction_GetInputs(self, engine: ExecutionEngine): + tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if tx is None: + return False + + if len(tx.inputs) > engine.maxArraySize: + return False + + inputs = [StackItem.FromInterface(input) for input in tx.inputs] + engine.CurrentContext.EvaluationStack.PushT(inputs) + return True + + def Transaction_GetOutputs(self, engine: ExecutionEngine): + tx = 
engine.CurrentContext.EvaluationStack.Pop().GetInterface() + + if tx is None: + return False + + if len(tx.outputs) > engine.maxArraySize: + return False + + outputs = [] + for output in tx.outputs: + stackoutput = StackItem.FromInterface(output) + outputs.append(stackoutput) + + engine.CurrentContext.EvaluationStack.PushT(outputs) + return True + + def Transaction_GetReferences(self, engine: ExecutionEngine): + tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + + if tx is None: + return False + + if len(tx.inputs) > engine.maxArraySize: + return False + + refs = [StackItem.FromInterface(tx.References[input]) for input in tx.inputs] + + engine.CurrentContext.EvaluationStack.PushT(refs) + return True + + def Transaction_GetUnspentCoins(self, engine: ExecutionEngine): + tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + + if tx is None: + return False + + outputs = Blockchain.Default().GetAllUnspent(tx.Hash) + if len(outputs) > engine.maxArraySize: + return False + + refs = [StackItem.FromInterface(unspent) for unspent in outputs] + engine.CurrentContext.EvaluationStack.PushT(refs) + return True + + def Transaction_GetWitnesses(self, engine: ExecutionEngine): + tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + + if tx is None: + return False + + if len(tx.scripts) > engine.maxArraySize: + return False + + witnesses = [StackItem.FromInterface(s) for s in tx.scripts] + engine.CurrentContext.EvaluationStack.PushT(witnesses) + return True + + def InvocationTransaction_GetScript(self, engine: ExecutionEngine): + tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if tx is None: + return False + engine.CurrentContext.EvaluationStack.PushT(tx.Script) + return True + + def Witness_GetVerificationScript(self, engine: ExecutionEngine): + witness = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if witness is None: + return False + engine.CurrentContext.EvaluationStack.PushT(witness.VerificationScript) + return True + + def Attribute_GetUsage(self, engine: ExecutionEngine): + attr = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if attr is None: + return False + engine.CurrentContext.EvaluationStack.PushT(attr.Usage) + return True + + def Attribute_GetData(self, engine: ExecutionEngine): + attr = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if attr is None: + return False + engine.CurrentContext.EvaluationStack.PushT(attr.Data) + return True + + def Input_GetHash(self, engine: ExecutionEngine): + input = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if input is None: + return False + engine.CurrentContext.EvaluationStack.PushT(input.PrevHash.ToArray()) + return True + + def Input_GetIndex(self, engine: ExecutionEngine): + input = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if input is None: + return False + + engine.CurrentContext.EvaluationStack.PushT(int(input.PrevIndex)) + return True + + def Output_GetAssetId(self, engine: ExecutionEngine): + output = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + + if output is None: + return False + + engine.CurrentContext.EvaluationStack.PushT(output.AssetId.ToArray()) + return True + + def Output_GetValue(self, engine: ExecutionEngine): + output = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if output is None: + return False + + engine.CurrentContext.EvaluationStack.PushT(output.Value.GetData()) + return True + + def Output_GetScriptHash(self, engine: ExecutionEngine): + output = 
engine.CurrentContext.EvaluationStack.Pop().GetInterface() + + if output is None: + return False + + engine.CurrentContext.EvaluationStack.PushT(output.ScriptHash.ToArray()) + return True + + def Account_GetScriptHash(self, engine: ExecutionEngine): + account = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if account is None: + return False + engine.CurrentContext.EvaluationStack.PushT(account.ScriptHash.ToArray()) + return True + + def Account_GetVotes(self, engine: ExecutionEngine): + account = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if account is None: + return False + + votes = [StackItem.FromInterface(v.EncodePoint(True)) for v in account.Votes] + engine.CurrentContext.EvaluationStack.PushT(votes) + return True + + def Account_GetBalance(self, engine: ExecutionEngine): + account = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + assetId = UInt256(data=engine.CurrentContext.EvaluationStack.Pop().GetByteArray()) + + if account is None: + return False + balance = account.BalanceFor(assetId) + engine.CurrentContext.EvaluationStack.PushT(balance.GetData()) + return True + + def Account_IsStandard(self, engine: ExecutionEngine): + # TODO: implement + # contract_hash = UInt160(data=engine.CurrentContext.EvaluationStack.Pop().GetByteArray()) + # contract = self._contracts.TryGet(contract_hash.ToBytes()) + # + # bool isStandard = contract is null | | contract.Script.IsStandardContract(); + # engine.CurrentContext.EvaluationStack.Push(isStandard); + # return true; + logger.error("Account_IsStandard not implemented!") + return False + def Asset_Create(self, engine: ExecutionEngine): tx = engine.ScriptContainer @@ -208,6 +596,62 @@ def Asset_Renew(self, engine: ExecutionEngine): return True + def Asset_GetAssetId(self, engine: ExecutionEngine): + asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if asset is None: + return False + engine.CurrentContext.EvaluationStack.PushT(asset.AssetId.ToArray()) + return True + + def Asset_GetAssetType(self, engine: ExecutionEngine): + asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if asset is None: + return False + engine.CurrentContext.EvaluationStack.PushT(asset.AssetType) + return True + + def Asset_GetAmount(self, engine: ExecutionEngine): + asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if asset is None: + return False + engine.CurrentContext.EvaluationStack.PushT(asset.Amount.GetData()) + return True + + def Asset_GetAvailable(self, engine: ExecutionEngine): + asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if asset is None: + return False + engine.CurrentContext.EvaluationStack.PushT(asset.Available.GetData()) + return True + + def Asset_GetPrecision(self, engine: ExecutionEngine): + asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if asset is None: + return False + engine.CurrentContext.EvaluationStack.PushT(asset.Precision) + return True + + def Asset_GetOwner(self, engine: ExecutionEngine): + asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if asset is None: + return False + engine.CurrentContext.EvaluationStack.PushT(asset.Owner.EncodePoint(True)) + return True + + def Asset_GetAdmin(self, engine: ExecutionEngine): + asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if asset is None: + return False + engine.CurrentContext.EvaluationStack.PushT(asset.Admin.ToArray()) + return True + + def Asset_GetIssuer(self, engine: ExecutionEngine): + asset = 
engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if asset is None: + return False + engine.CurrentContext.EvaluationStack.PushT(asset.Issuer.ToArray()) + return True + def Contract_Create(self, engine: ExecutionEngine): script = engine.CurrentContext.EvaluationStack.Pop().GetByteArray() @@ -333,107 +777,114 @@ def Contract_Migrate(self, engine: ExecutionEngine): return self.Contract_Destroy(engine) - def Contract_GetStorageContext(self, engine): + def Contract_GetScript(self, engine: ExecutionEngine): + contract = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if contract is None: + return False + engine.CurrentContext.EvaluationStack.PushT(contract.Code.Script) + return True + def Contract_IsPayable(self, engine: ExecutionEngine): contract = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if contract is None: + return False + engine.CurrentContext.EvaluationStack.PushT(contract.Payable) + return True - shash = contract.Code.ScriptHash() + def Storage_Find(self, engine: ExecutionEngine): + context = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if context is None: + return False - if shash.ToBytes() in self._contracts_created: + if not self.CheckStorageContext(context): + return False - created = self._contracts_created[shash.ToBytes()] + prefix = engine.CurrentContext.EvaluationStack.Pop().GetByteArray() + prefix = context.ScriptHash.ToArray() + prefix - if created == UInt160(data=engine.CurrentContext.ScriptHash()): - context = StorageContext(script_hash=shash) - engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(context)) + iterator = self.Storages.TryFind(prefix) + engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(iterator)) - return True + return True + def Enumerator_Create(self, engine: ExecutionEngine): + item = engine.CurrentContext.EvaluationStack.Pop() + if isinstance(item, Array): + enumerator = ArrayWrapper(item) + engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(enumerator)) + return True return False - def Contract_Destroy(self, engine): - hash = UInt160(data=engine.CurrentContext.ScriptHash()) - - contract = self._contracts.TryGet(hash.ToBytes()) - - if contract is not None: - - self._contracts.Remove(hash.ToBytes()) - - if contract.HasStorage: - - for pair in self._storages.Find(hash.ToBytes()): - self._storages.Remove(pair.Key) - - self.events_to_dispatch.append( - SmartContractEvent(SmartContractEvent.CONTRACT_DESTROY, ContractParameter(ContractParameterType.InteropInterface, contract), - hash, Blockchain.Default().Height + 1, - engine.ScriptContainer.Hash if engine.ScriptContainer else None, - test_mode=engine.testMode)) + def Enumerator_Next(self, engine: ExecutionEngine): + item = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if item is None: + return False + engine.CurrentContext.EvaluationStack.PushT(item.Next()) return True - def Storage_Put(self, engine: ExecutionEngine): - - context = None - try: - - context = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - except Exception as e: - logger.error("Storage Context Not found on stack") + def Enumerator_Value(self, engine: ExecutionEngine): + item = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if item is None: return False - if not self.CheckStorageContext(context): - return False + engine.CurrentContext.EvaluationStack.PushT(item.Value()) + return True - key = engine.CurrentContext.EvaluationStack.Pop().GetByteArray() - if len(key) > 1024: + def 
Enumerator_Concat(self, engine: ExecutionEngine): + item1 = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if item1 is None: return False - value = engine.CurrentContext.EvaluationStack.Pop().GetByteArray() - - new_item = StorageItem(value=value) - storage_key = StorageKey(script_hash=context.ScriptHash, key=key) - item = self._storages.ReplaceOrAdd(storage_key.ToArray(), new_item) - - keystr = key - valStr = bytearray(item.Value) + item2 = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if item2 is None: + return False - if len(key) == 20: - keystr = Crypto.ToAddress(UInt160(data=key)) + result = ConcatenatedEnumerator(item1, item2) + engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(result)) + return True - try: - valStr = int.from_bytes(valStr, 'little') - except Exception as e: - pass + def Iterator_Create(self, engine: ExecutionEngine): + item = engine.CurrentContext.EvaluationStack.Pop() + if isinstance(item, Map): + iterator = MapWrapper(item) + engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(iterator)) + return True + return False - self.events_to_dispatch.append( - SmartContractEvent(SmartContractEvent.STORAGE_PUT, ContractParameter(ContractParameterType.String, '%s -> %s' % (keystr, valStr)), - context.ScriptHash, Blockchain.Default().Height + 1, - engine.ScriptContainer.Hash if engine.ScriptContainer else None, - test_mode=engine.testMode)) + def Iterator_Key(self, engine: ExecutionEngine): + iterator = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if iterator is None: + return False + engine.CurrentContext.EvaluationStack.PushT(iterator.Key()) return True - def Storage_Delete(self, engine: ExecutionEngine): - - context = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - - if not self.CheckStorageContext(context): + def Iterator_Keys(self, engine: ExecutionEngine): + iterator = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if iterator is None: return False + wrapper = StackItem.FromInterface(KeysWrapper(iterator)) + engine.CurrentContext.EvaluationStack.PushT(wrapper) + return True - key = engine.CurrentContext.EvaluationStack.Pop().GetByteArray() - - storage_key = StorageKey(script_hash=context.ScriptHash, key=key) + def Iterator_Values(self, engine: ExecutionEngine): + iterator = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if iterator is None: + return False - keystr = key - if len(key) == 20: - keystr = Crypto.ToAddress(UInt160(data=key)) + wrapper = StackItem.FromInterface(ValuesWrapper(iterator)) + engine.CurrentContext.EvaluationStack.PushT(wrapper) + return True - self.events_to_dispatch.append(SmartContractEvent(SmartContractEvent.STORAGE_DELETE, ContractParameter(ContractParameterType.String, keystr), - context.ScriptHash, Blockchain.Default().Height + 1, - engine.ScriptContainer.Hash if engine.ScriptContainer else None, - test_mode=engine.testMode)) + def Iterator_Concat(self, engine: ExecutionEngine): + item1 = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if item1 is None: + return False - self._storages.Remove(storage_key.ToArray()) + item2 = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + if item2 is None: + return False + result = ConcatenatedIterator(item1, item2) + engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(result)) return True diff --git a/neo/SmartContract/StateReader.py b/neo/SmartContract/StateReader.py index 78be62c7c..a3c20d61a 100644 --- a/neo/SmartContract/StateReader.py +++ 
b/neo/SmartContract/StateReader.py @@ -18,9 +18,6 @@ from neo.Core.IO.BinaryReader import BinaryReader from neo.Core.IO.BinaryWriter import BinaryWriter from neo.IO.MemoryStream import StreamManager -from neo.SmartContract.Iterable.Wrapper import ArrayWrapper, MapWrapper -from neo.SmartContract.Iterable import KeysWrapper, ValuesWrapper -from neo.SmartContract.Iterable.ConcatenatedEnumerator import ConcatenatedEnumerator from neo.Implementations.Blockchains.LevelDB.DBPrefix import DBPrefix from neo.Core.State.ContractState import ContractState from neo.Core.State.AccountState import AccountState @@ -32,18 +29,6 @@ class StateReader(InteropService): - notifications = None - - events_to_dispatch = [] - - __Instance = None - - _hashes_for_verifying = None - - _accounts = None - _assets = None - _contracts = None - _storages = None @property def Accounts(self): @@ -69,11 +54,9 @@ def Storages(self): self._storages = Blockchain.Default().GetStates(DBPrefix.ST_Storage, StorageItem) return self._storages - @staticmethod - def Instance(): - if StateReader.__Instance is None: - StateReader.__Instance = StateReader() - return StateReader.__Instance + def RegisterWithPrice(self, method, func, price): + self._dictionary[method] = func + self.prices.update({hash(method): price}) def __init__(self): @@ -81,169 +64,53 @@ def __init__(self): self.notifications = [] self.events_to_dispatch = [] + self.prices = {} + self._hashes_for_verifying = None + self._accounts = None + self._assets = None + self._contracts = None + self._storages = None + + # TODO: move ExecutionEngine calls here as well from /neo/VM/InteropService/ # Standard Library - self.Register("System.Runtime.GetTrigger", self.Runtime_GetTrigger) - self.Register("System.Runtime.CheckWitness", self.Runtime_CheckWitness) - self.Register("System.Runtime.Notify", self.Runtime_Notify) - self.Register("System.Runtime.Log", self.Runtime_Log) - self.Register("System.Runtime.GetTime", self.Runtime_GetCurrentTime) - self.Register("System.Runtime.Serialize", self.Runtime_Serialize) - self.Register("System.Runtime.Deserialize", self.Runtime_Deserialize) - self.Register("System.Blockchain.GetHeight", self.Blockchain_GetHeight) - self.Register("System.Blockchain.GetHeader", self.Blockchain_GetHeader) - self.Register("System.Blockchain.GetBlock", self.Blockchain_GetBlock) - self.Register("System.Blockchain.GetTransaction", self.Blockchain_GetTransaction) - self.Register("System.Blockchain.GetTransactionHeight", self.Blockchain_GetTransactionHeight) - self.Register("System.Blockchain.GetContract", self.Blockchain_GetContract) - self.Register("System.Header.GetIndex", self.Header_GetIndex) - self.Register("System.Header.GetHash", self.Header_GetHash) - self.Register("System.Header.GetVersion", self.Header_GetVersion) - self.Register("System.Header.GetPrevHash", self.Header_GetPrevHash) - self.Register("System.Header.GetTimestamp", self.Header_GetTimestamp) - self.Register("System.Block.GetTransactionCount", self.Block_GetTransactionCount) - self.Register("System.Block.GetTransactions", self.Block_GetTransactions) - self.Register("System.Block.GetTransaction", self.Block_GetTransaction) - self.Register("System.Transaction.GetHash", self.Transaction_GetHash) - self.Register("System.Storage.GetContext", self.Storage_GetContext) - self.Register("System.Storage.GetReadOnlyContext", self.Storage_GetReadOnlyContext) - self.Register("System.Storage.Get", self.Storage_Get) - self.Register("System.StorageContext.AsReadOnly", self.StorageContext_AsReadOnly) - - # Neo 
Specific - self.Register("Neo.Blockchain.GetAccount", self.Blockchain_GetAccount) - self.Register("Neo.Blockchain.GetValidators", self.Blockchain_GetValidators) - self.Register("Neo.Blockchain.GetAsset", self.Blockchain_GetAsset) - self.Register("Neo.Header.GetMerkleRoot", self.Header_GetMerkleRoot) - self.Register("Neo.Header.GetConsensusData", self.Header_GetConsensusData) - self.Register("Neo.Header.GetNextConsensus", self.Header_GetNextConsensus) - self.Register("Neo.Transaction.GetType", self.Transaction_GetType) - self.Register("Neo.Transaction.GetAttributes", self.Transaction_GetAttributes) - self.Register("Neo.Transaction.GetInputs", self.Transaction_GetInputs) - self.Register("Neo.Transaction.GetOutputs", self.Transaction_GetOutputs) - self.Register("Neo.Transaction.GetReferences", self.Transaction_GetReferences) - self.Register("Neo.Transaction.GetUnspentCoins", self.Transaction_GetUnspentCoins) - self.Register("Neo.Transaction.GetWitnesses", self.Transaction_GetWitnesses) - self.Register("Neo.InvocationTransaction.GetScript", self.InvocationTransaction_GetScript) - self.Register("Neo.Witness.GetVerificationScript", self.Witness_GetVerificationScript) - self.Register("Neo.Attribute.GetUsage", self.Attribute_GetUsage) - self.Register("Neo.Attribute.GetData", self.Attribute_GetData) - self.Register("Neo.Input.GetHash", self.Input_GetHash) - self.Register("Neo.Input.GetIndex", self.Input_GetIndex) - self.Register("Neo.Output.GetAssetId", self.Output_GetAssetId) - self.Register("Neo.Output.GetValue", self.Output_GetValue) - self.Register("Neo.Output.GetScriptHash", self.Output_GetScriptHash) - self.Register("Neo.Account.GetScriptHash", self.Account_GetScriptHash) - self.Register("Neo.Account.GetVotes", self.Account_GetVotes) - self.Register("Neo.Account.GetBalance", self.Account_GetBalance) - self.Register("Neo.Asset.GetAssetId", self.Asset_GetAssetId) - self.Register("Neo.Asset.GetAssetType", self.Asset_GetAssetType) - self.Register("Neo.Asset.GetAmount", self.Asset_GetAmount) - self.Register("Neo.Asset.GetAvailable", self.Asset_GetAvailable) - self.Register("Neo.Asset.GetPrecision", self.Asset_GetPrecision) - self.Register("Neo.Asset.GetOwner", self.Asset_GetOwner) - self.Register("Neo.Asset.GetAdmin", self.Asset_GetAdmin) - self.Register("Neo.Asset.GetIssuer", self.Asset_GetIssuer) - self.Register("Neo.Contract.GetScript", self.Contract_GetScript) - self.Register("Neo.Contract.IsPayable", self.Contract_IsPayable) - self.Register("Neo.Storage.Find", self.Storage_Find) - self.Register("Neo.Enumerator.Create", self.Enumerator_Create) - self.Register("Neo.Enumerator.Next", self.Enumerator_Next) - self.Register("Neo.Enumerator.Value", self.Enumerator_Value) - self.Register("Neo.Enumerator.Concat", self.Enumerator_Concat) - self.Register("Neo.Iterator.Create", self.Iterator_Create) - self.Register("Neo.Iterator.Key", self.Iterator_Key) - self.Register("Neo.Iterator.Keys", self.Iterator_Keys) - self.Register("Neo.Iterator.Values", self.Iterator_Values) - - # Old Iterator aliases - self.Register("Neo.Iterator.Next", self.Enumerator_Next) - self.Register("Neo.Iterator.Value", self.Enumerator_Value) - - # Old API - # Standard Library - self.Register("Neo.Runtime.GetTrigger", self.Runtime_GetTrigger) - self.Register("Neo.Runtime.CheckWitness", self.Runtime_CheckWitness) - self.Register("Neo.Runtime.Notify", self.Runtime_Notify) - self.Register("Neo.Runtime.Log", self.Runtime_Log) - self.Register("Neo.Runtime.GetTime", self.Runtime_GetCurrentTime) - self.Register("Neo.Runtime.Serialize", 
self.Runtime_Serialize) - self.Register("Neo.Runtime.Deserialize", self.Runtime_Deserialize) - self.Register("Neo.Blockchain.GetHeight", self.Blockchain_GetHeight) - self.Register("Neo.Blockchain.GetHeader", self.Blockchain_GetHeader) - self.Register("Neo.Blockchain.GetBlock", self.Blockchain_GetBlock) - self.Register("Neo.Blockchain.GetTransaction", self.Blockchain_GetTransaction) - self.Register("Neo.Blockchain.GetTransactionHeight", self.Blockchain_GetTransactionHeight) - self.Register("Neo.Blockchain.GetContract", self.Blockchain_GetContract) - self.Register("Neo.Header.GetIndex", self.Header_GetIndex) - self.Register("Neo.Header.GetHash", self.Header_GetHash) - self.Register("Neo.Header.GetVersion", self.Header_GetVersion) - self.Register("Neo.Header.GetPrevHash", self.Header_GetPrevHash) - self.Register("Neo.Header.GetTimestamp", self.Header_GetTimestamp) - self.Register("Neo.Block.GetTransactionCount", self.Block_GetTransactionCount) - self.Register("Neo.Block.GetTransactions", self.Block_GetTransactions) - self.Register("Neo.Block.GetTransaction", self.Block_GetTransaction) - self.Register("Neo.Transaction.GetHash", self.Transaction_GetHash) - self.Register("Neo.Storage.GetContext", self.Storage_GetContext) - self.Register("Neo.Storage.GetReadOnlyContext", self.Storage_GetReadOnlyContext) - self.Register("Neo.Storage.Get", self.Storage_Get) - self.Register("Neo.StorageContext.AsReadOnly", self.StorageContext_AsReadOnly) - - # Very OLD API - self.Register("AntShares.Runtime.GetTrigger", self.Runtime_GetTrigger) - self.Register("AntShares.Runtime.CheckWitness", self.Runtime_CheckWitness) - self.Register("AntShares.Runtime.Notify", self.Runtime_Notify) - self.Register("AntShares.Runtime.Log", self.Runtime_Log) - self.Register("AntShares.Blockchain.GetHeight", self.Blockchain_GetHeight) - self.Register("AntShares.Blockchain.GetHeader", self.Blockchain_GetHeader) - self.Register("AntShares.Blockchain.GetBlock", self.Blockchain_GetBlock) - self.Register("AntShares.Blockchain.GetTransaction", self.Blockchain_GetTransaction) - self.Register("AntShares.Blockchain.GetAccount", self.Blockchain_GetAccount) - self.Register("AntShares.Blockchain.GetValidators", self.Blockchain_GetValidators) - self.Register("AntShares.Blockchain.GetAsset", self.Blockchain_GetAsset) - self.Register("AntShares.Blockchain.GetContract", self.Blockchain_GetContract) - self.Register("AntShares.Header.GetHash", self.Header_GetHash) - self.Register("AntShares.Header.GetVersion", self.Header_GetVersion) - self.Register("AntShares.Header.GetPrevHash", self.Header_GetPrevHash) - self.Register("AntShares.Header.GetMerkleRoot", self.Header_GetMerkleRoot) - self.Register("AntShares.Header.GetTimestamp", self.Header_GetTimestamp) - self.Register("AntShares.Header.GetConsensusData", self.Header_GetConsensusData) - self.Register("AntShares.Header.GetNextConsensus", self.Header_GetNextConsensus) - self.Register("AntShares.Block.GetTransactionCount", self.Block_GetTransactionCount) - self.Register("AntShares.Block.GetTransactions", self.Block_GetTransactions) - self.Register("AntShares.Block.GetTransaction", self.Block_GetTransaction) - self.Register("AntShares.Transaction.GetHash", self.Transaction_GetHash) - self.Register("AntShares.Transaction.GetType", self.Transaction_GetType) - self.Register("AntShares.Transaction.GetAttributes", self.Transaction_GetAttributes) - self.Register("AntShares.Transaction.GetInputs", self.Transaction_GetInputs) - self.Register("AntShares.Transaction.GetOutpus", self.Transaction_GetOutputs) - 
self.Register("AntShares.Transaction.GetReferences", self.Transaction_GetReferences) - self.Register("AntShares.Attribute.GetData", self.Attribute_GetData) - self.Register("AntShares.Attribute.GetUsage", self.Attribute_GetUsage) - self.Register("AntShares.Input.GetHash", self.Input_GetHash) - self.Register("AntShares.Input.GetIndex", self.Input_GetIndex) - self.Register("AntShares.Output.GetAssetId", self.Output_GetAssetId) - self.Register("AntShares.Output.GetValue", self.Output_GetValue) - self.Register("AntShares.Output.GetScriptHash", self.Output_GetScriptHash) - self.Register("AntShares.Account.GetVotes", self.Account_GetVotes) - self.Register("AntShares.Account.GetBalance", self.Account_GetBalance) - self.Register("AntShares.Account.GetScriptHash", self.Account_GetScriptHash) - self.Register("AntShares.Asset.GetAssetId", self.Asset_GetAssetId) - self.Register("AntShares.Asset.GetAssetType", self.Asset_GetAssetType) - self.Register("AntShares.Asset.GetAmount", self.Asset_GetAmount) - self.Register("AntShares.Asset.GetAvailable", self.Asset_GetAvailable) - self.Register("AntShares.Asset.GetPrecision", self.Asset_GetPrecision) - self.Register("AntShares.Asset.GetOwner", self.Asset_GetOwner) - self.Register("AntShares.Asset.GetAdmin", self.Asset_GetAdmin) - self.Register("AntShares.Asset.GetIssuer", self.Asset_GetIssuer) - self.Register("AntShares.Contract.GetScript", self.Contract_GetScript) - self.Register("AntShares.Storage.GetContext", self.Storage_GetContext) - self.Register("AntShares.Storage.Get", self.Storage_Get) + self.RegisterWithPrice("System.Runtime.Platform", self.Runtime_Platform, 1) + self.RegisterWithPrice("System.Runtime.GetTrigger", self.Runtime_GetTrigger, 1) + self.RegisterWithPrice("System.Runtime.CheckWitness", self.Runtime_CheckWitness, 200) + self.RegisterWithPrice("System.Runtime.Notify", self.Runtime_Notify, 1) + self.RegisterWithPrice("System.Runtime.Log", self.Runtime_Log, 1) + self.RegisterWithPrice("System.Runtime.GetTime", self.Runtime_GetCurrentTime, 1) + self.RegisterWithPrice("System.Runtime.Serialize", self.Runtime_Serialize, 1) + self.RegisterWithPrice("System.Runtime.Deserialize", self.Runtime_Deserialize, 1) + self.RegisterWithPrice("System.Blockchain.GetHeight", self.Blockchain_GetHeight, 1) + self.RegisterWithPrice("System.Blockchain.GetHeader", self.Blockchain_GetHeader, 100) + self.RegisterWithPrice("System.Blockchain.GetBlock", self.Blockchain_GetBlock, 200) + self.RegisterWithPrice("System.Blockchain.GetTransaction", self.Blockchain_GetTransaction, 200) + self.RegisterWithPrice("System.Blockchain.GetTransactionHeight", self.Blockchain_GetTransactionHeight, 100) + self.RegisterWithPrice("System.Blockchain.GetContract", self.Blockchain_GetContract, 100) + self.RegisterWithPrice("System.Header.GetIndex", self.Header_GetIndex, 1) + self.RegisterWithPrice("System.Header.GetHash", self.Header_GetHash, 1) + self.RegisterWithPrice("System.Header.GetPrevHash", self.Header_GetPrevHash, 1) + self.RegisterWithPrice("System.Header.GetTimestamp", self.Header_GetTimestamp, 1) + self.RegisterWithPrice("System.Block.GetTransactionCount", self.Block_GetTransactionCount, 1) + self.RegisterWithPrice("System.Block.GetTransactions", self.Block_GetTransactions, 1) + self.RegisterWithPrice("System.Block.GetTransaction", self.Block_GetTransaction, 1) + self.RegisterWithPrice("System.Transaction.GetHash", self.Transaction_GetHash, 1) + self.RegisterWithPrice("System.Storage.GetContext", self.Storage_GetContext, 1) + 
self.RegisterWithPrice("System.Storage.GetReadOnlyContext", self.Storage_GetReadOnlyContext, 1) + self.RegisterWithPrice("System.Storage.Get", self.Storage_Get, 100) + self.Register("System.Storage.Put", self.Storage_Put) + self.Register("System.Storage.PutEx", self.Storage_PutEx) + self.RegisterWithPrice("System.Storage.Delete", self.Storage_Delete, 100) + self.RegisterWithPrice("System.StorageContext.AsReadOnly", self.StorageContext_AsReadOnly, 1) def CheckStorageContext(self, context): if context is None: return False + if type(context) != StorageContext: + return False + contract = self.Contracts.TryGet(context.ScriptHash.ToBytes()) if contract is not None: @@ -252,6 +119,9 @@ def CheckStorageContext(self, context): return False + def GetPrice(self, hash: int): + return self.prices.get(hash, 0) + def ExecutionCompleted(self, engine, success, error=None): height = Blockchain.Default().Height + 1 tx_hash = None @@ -303,7 +173,8 @@ def ExecutionCompleted(self, engine, success, error=None): # If we do not add the eval stack, then exceptions that are raised in a contract # are not displayed to the event consumer - [payload.Value.append(ContractParameter.ToParameter(item)) for item in engine.CurrentContext.EvaluationStack.Items] + if engine._InvocationStack.Count > 1: + [payload.Value.append(ContractParameter.ToParameter(item)) for item in engine.CurrentContext.EvaluationStack.Items] if engine.Trigger == Application: self.events_to_dispatch.append( @@ -316,6 +187,10 @@ def ExecutionCompleted(self, engine, success, error=None): self.notifications = [] + def Runtime_Platform(self, engine): + engine.CurrentContext.EvaluationStack.PushT(b'\x4e\x45\x4f') # NEO + return True + def Runtime_GetTrigger(self, engine): engine.CurrentContext.EvaluationStack.PushT(engine.Trigger) @@ -419,12 +294,14 @@ def Runtime_Serialize(self, engine: ExecutionEngine): try: stack_item.Serialize(writer) except Exception as e: + StreamManager.ReleaseStream(ms) logger.error("Cannot serialize item %s: %s " % (stack_item, e)) return False ms.flush() if ms.tell() > engine.maxItemSize: + StreamManager.ReleaseStream(ms) return False retVal = ByteArray(ms.getvalue()) @@ -445,6 +322,8 @@ def Runtime_Deserialize(self, engine: ExecutionEngine): # can't deserialize type logger.error("%s " % e) return False + finally: + StreamManager.ReleaseStream(ms) return True def Blockchain_GetHeight(self, engine: ExecutionEngine): @@ -543,34 +422,6 @@ def Blockchain_GetTransactionHeight(self, engine: ExecutionEngine): engine.CurrentContext.EvaluationStack.PushT(height) return True - def Blockchain_GetAccount(self, engine: ExecutionEngine): - hash = UInt160(data=engine.CurrentContext.EvaluationStack.Pop().GetByteArray()) - address = Crypto.ToAddress(hash).encode('utf-8') - - account = self.Accounts.GetOrAdd(address, new_instance=AccountState(script_hash=hash)) - engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(account)) - return True - - def Blockchain_GetValidators(self, engine: ExecutionEngine): - validators = Blockchain.Default().GetValidators() - - items = [StackItem(validator.encode_point(compressed=True)) for validator in validators] - - engine.CurrentContext.EvaluationStack.PushT(items) - - return True - - def Blockchain_GetAsset(self, engine: ExecutionEngine): - data = engine.CurrentContext.EvaluationStack.Pop().GetByteArray() - asset = None - - if Blockchain.Default() is not None: - asset = self.Assets.TryGet(UInt256(data=data)) - if asset is None: - return False - 
engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(asset)) - return True - def Blockchain_GetContract(self, engine: ExecutionEngine): hash = UInt160(data=engine.CurrentContext.EvaluationStack.Pop().GetByteArray()) contract = self.Contracts.TryGet(hash.ToBytes()) @@ -594,13 +445,6 @@ def Header_GetHash(self, engine: ExecutionEngine): engine.CurrentContext.EvaluationStack.PushT(header.Hash.ToArray()) return True - def Header_GetVersion(self, engine: ExecutionEngine): - header = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if header is None: - return False - engine.CurrentContext.EvaluationStack.PushT(header.Version) - return True - def Header_GetPrevHash(self, engine: ExecutionEngine): header = engine.CurrentContext.EvaluationStack.Pop().GetInterface() if header is None: @@ -608,13 +452,6 @@ def Header_GetPrevHash(self, engine: ExecutionEngine): engine.CurrentContext.EvaluationStack.PushT(header.PrevHash.ToArray()) return True - def Header_GetMerkleRoot(self, engine: ExecutionEngine): - header = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if header is None: - return False - engine.CurrentContext.EvaluationStack.PushT(header.MerkleRoot.ToArray()) - return True - def Header_GetTimestamp(self, engine: ExecutionEngine): header = engine.CurrentContext.EvaluationStack.Pop().GetInterface() if header is None: @@ -623,20 +460,6 @@ def Header_GetTimestamp(self, engine: ExecutionEngine): return True - def Header_GetConsensusData(self, engine: ExecutionEngine): - header = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if header is None: - return False - engine.CurrentContext.EvaluationStack.PushT(header.ConsensusData) - return True - - def Header_GetNextConsensus(self, engine: ExecutionEngine): - header = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if header is None: - return False - engine.CurrentContext.EvaluationStack.PushT(header.NextConsensus.ToArray()) - return True - def Block_GetTransactionCount(self, engine: ExecutionEngine): block = engine.CurrentContext.EvaluationStack.Pop().GetInterface() if block is None: @@ -675,264 +498,6 @@ def Transaction_GetHash(self, engine: ExecutionEngine): engine.CurrentContext.EvaluationStack.PushT(tx.Hash.ToArray()) return True - def Transaction_GetType(self, engine: ExecutionEngine): - tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if tx is None: - return False - - if isinstance(tx.Type, bytes): - engine.CurrentContext.EvaluationStack.PushT(tx.Type) - else: - engine.CurrentContext.EvaluationStack.PushT(tx.Type.to_bytes(1, 'little')) - return True - - def Transaction_GetAttributes(self, engine: ExecutionEngine): - tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if tx is None: - return False - - if len(tx.Attributes) > engine.maxArraySize: - return False - - attr = [StackItem.FromInterface(attr) for attr in tx.Attributes] - engine.CurrentContext.EvaluationStack.PushT(attr) - return True - - def Transaction_GetInputs(self, engine: ExecutionEngine): - tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if tx is None: - return False - - if len(tx.inputs) > engine.maxArraySize: - return False - - inputs = [StackItem.FromInterface(input) for input in tx.inputs] - engine.CurrentContext.EvaluationStack.PushT(inputs) - return True - - def Transaction_GetOutputs(self, engine: ExecutionEngine): - tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - - if tx is None: - return False - - if len(tx.outputs) > engine.maxArraySize: - 
return False - - outputs = [] - for output in tx.outputs: - stackoutput = StackItem.FromInterface(output) - outputs.append(stackoutput) - - engine.CurrentContext.EvaluationStack.PushT(outputs) - return True - - def Transaction_GetReferences(self, engine: ExecutionEngine): - tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - - if tx is None: - return False - - if len(tx.inputs) > engine.maxArraySize: - return False - - refs = [StackItem.FromInterface(tx.References[input]) for input in tx.inputs] - - engine.CurrentContext.EvaluationStack.PushT(refs) - return True - - def Transaction_GetUnspentCoins(self, engine: ExecutionEngine): - tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - - if tx is None: - return False - - outputs = Blockchain.Default().GetAllUnspent(tx.Hash) - if len(outputs) > engine.maxArraySize: - return False - - refs = [StackItem.FromInterface(unspent) for unspent in outputs] - engine.CurrentContext.EvaluationStack.PushT(refs) - return True - - def Transaction_GetWitnesses(self, engine: ExecutionEngine): - tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - - if tx is None: - return False - - if len(tx.scripts) > engine.maxArraySize: - return False - - witnesses = [StackItem.FromInterface(s) for s in tx.scripts] - engine.CurrentContext.EvaluationStack.PushT(witnesses) - return True - - def InvocationTransaction_GetScript(self, engine: ExecutionEngine): - tx = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if tx is None: - return False - engine.CurrentContext.EvaluationStack.PushT(tx.Script) - return True - - def Witness_GetVerificationScript(self, engine: ExecutionEngine): - witness = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if witness is None: - return False - engine.CurrentContext.EvaluationStack.PushT(witness.VerificationScript) - return True - - def Attribute_GetUsage(self, engine: ExecutionEngine): - attr = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if attr is None: - return False - engine.CurrentContext.EvaluationStack.PushT(attr.Usage) - return True - - def Attribute_GetData(self, engine: ExecutionEngine): - attr = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if attr is None: - return False - engine.CurrentContext.EvaluationStack.PushT(attr.Data) - return True - - def Input_GetHash(self, engine: ExecutionEngine): - input = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if input is None: - return False - engine.CurrentContext.EvaluationStack.PushT(input.PrevHash.ToArray()) - return True - - def Input_GetIndex(self, engine: ExecutionEngine): - input = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if input is None: - return False - - engine.CurrentContext.EvaluationStack.PushT(int(input.PrevIndex)) - return True - - def Output_GetAssetId(self, engine: ExecutionEngine): - output = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - - if output is None: - return False - - engine.CurrentContext.EvaluationStack.PushT(output.AssetId.ToArray()) - return True - - def Output_GetValue(self, engine: ExecutionEngine): - output = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if output is None: - return False - - engine.CurrentContext.EvaluationStack.PushT(output.Value.GetData()) - return True - - def Output_GetScriptHash(self, engine: ExecutionEngine): - output = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - - if output is None: - return False - - 
engine.CurrentContext.EvaluationStack.PushT(output.ScriptHash.ToArray()) - return True - - def Account_GetScriptHash(self, engine: ExecutionEngine): - account = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if account is None: - return False - engine.CurrentContext.EvaluationStack.PushT(account.ScriptHash.ToArray()) - return True - - def Account_GetVotes(self, engine: ExecutionEngine): - account = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if account is None: - return False - - votes = [StackItem.FromInterface(v.EncodePoint(True)) for v in account.Votes] - engine.CurrentContext.EvaluationStack.PushT(votes) - return True - - def Account_GetBalance(self, engine: ExecutionEngine): - account = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - assetId = UInt256(data=engine.CurrentContext.EvaluationStack.Pop().GetByteArray()) - - if account is None: - return False - balance = account.BalanceFor(assetId) - engine.CurrentContext.EvaluationStack.PushT(balance.GetData()) - return True - - def Asset_GetAssetId(self, engine: ExecutionEngine): - asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if asset is None: - return False - engine.CurrentContext.EvaluationStack.PushT(asset.AssetId.ToArray()) - return True - - def Asset_GetAssetType(self, engine: ExecutionEngine): - asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if asset is None: - return False - engine.CurrentContext.EvaluationStack.PushT(asset.AssetType) - return True - - def Asset_GetAmount(self, engine: ExecutionEngine): - asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if asset is None: - return False - engine.CurrentContext.EvaluationStack.PushT(asset.Amount.GetData()) - return True - - def Asset_GetAvailable(self, engine: ExecutionEngine): - asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if asset is None: - return False - engine.CurrentContext.EvaluationStack.PushT(asset.Available.GetData()) - return True - - def Asset_GetPrecision(self, engine: ExecutionEngine): - asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if asset is None: - return False - engine.CurrentContext.EvaluationStack.PushT(asset.Precision) - return True - - def Asset_GetOwner(self, engine: ExecutionEngine): - asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if asset is None: - return False - engine.CurrentContext.EvaluationStack.PushT(asset.Owner.EncodePoint(True)) - return True - - def Asset_GetAdmin(self, engine: ExecutionEngine): - asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if asset is None: - return False - engine.CurrentContext.EvaluationStack.PushT(asset.Admin.ToArray()) - return True - - def Asset_GetIssuer(self, engine: ExecutionEngine): - asset = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if asset is None: - return False - engine.CurrentContext.EvaluationStack.PushT(asset.Issuer.ToArray()) - return True - - def Contract_GetScript(self, engine: ExecutionEngine): - contract = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if contract is None: - return False - engine.CurrentContext.EvaluationStack.PushT(contract.Code.Script) - return True - - def Contract_IsPayable(self, engine: ExecutionEngine): - contract = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if contract is None: - return False - engine.CurrentContext.EvaluationStack.PushT(contract.Payable) - return True - def Storage_GetContext(self, engine: ExecutionEngine): hash = 
UInt160(data=engine.CurrentContext.ScriptHash()) context = StorageContext(script_hash=hash) @@ -984,14 +549,6 @@ def Storage_Get(self, engine: ExecutionEngine): if item is not None: valStr = bytearray(item.Value) - if len(key) == 20: - keystr = Crypto.ToAddress(UInt160(data=key)) - - try: - valStr = int.from_bytes(valStr, 'little') - except Exception as e: - logger.error("Could not convert %s to number: %s " % (valStr, e)) - if item is not None: engine.CurrentContext.EvaluationStack.PushT(bytearray(item.Value)) @@ -1008,87 +565,111 @@ def Storage_Get(self, engine: ExecutionEngine): return True - def Storage_Find(self, engine: ExecutionEngine): - context = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if context is None: + def Storage_Put(self, engine: ExecutionEngine): + + context = None + try: + context = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + except Exception as e: + logger.error("Storage Context Not found on stack") return False if not self.CheckStorageContext(context): return False - prefix = engine.CurrentContext.EvaluationStack.Pop().GetByteArray() - prefix = context.ScriptHash.ToArray() + prefix + key = engine.CurrentContext.EvaluationStack.Pop().GetByteArray() + if len(key) > 1024: + return False + + value = engine.CurrentContext.EvaluationStack.Pop().GetByteArray() + + new_item = StorageItem(value=value) + storage_key = StorageKey(script_hash=context.ScriptHash, key=key) + item = self._storages.ReplaceOrAdd(storage_key.ToArray(), new_item) - iterator = self.Storages.TryFind(prefix) - engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(iterator)) + keystr = key + valStr = bytearray(item.Value) + + if type(engine) == ExecutionEngine: + test_mode = False + else: + test_mode = engine.testMode + + self.events_to_dispatch.append( + SmartContractEvent(SmartContractEvent.STORAGE_PUT, ContractParameter(ContractParameterType.String, '%s -> %s' % (keystr, valStr)), + context.ScriptHash, Blockchain.Default().Height + 1, + engine.ScriptContainer.Hash if engine.ScriptContainer else None, + test_mode=test_mode)) return True - def Enumerator_Create(self, engine: ExecutionEngine): - item = engine.CurrentContext.EvaluationStack.Pop() - if isinstance(item, Array): - enumerator = ArrayWrapper(item) - engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(enumerator)) - return True + def Contract_GetStorageContext(self, engine): + + contract = engine.CurrentContext.EvaluationStack.Pop().GetInterface() + + shash = contract.Code.ScriptHash() + + if shash.ToBytes() in self._contracts_created: + + created = self._contracts_created[shash.ToBytes()] + + if created == UInt160(data=engine.CurrentContext.ScriptHash()): + context = StorageContext(script_hash=shash) + engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(context)) + + return True + return False - def Enumerator_Next(self, engine: ExecutionEngine): - item = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if item is None: - return False - engine.CurrentContext.EvaluationStack.PushT(item.Next()) - return True + def Contract_Destroy(self, engine): + hash = UInt160(data=engine.CurrentContext.ScriptHash()) - def Enumerator_Value(self, engine: ExecutionEngine): - item = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if item is None: - return False + contract = self._contracts.TryGet(hash.ToBytes()) - engine.CurrentContext.EvaluationStack.PushT(item.Value()) - return True + if contract is not None: - def Enumerator_Concat(self, engine: 
ExecutionEngine): - item1 = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if item1 is None: - return False + self._contracts.Remove(hash.ToBytes()) - item2 = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if item2 is None: - return False + if contract.HasStorage: - result = ConcatenatedEnumerator(item1, item2) - engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(result)) + for pair in self._storages.Find(hash.ToBytes()): + self._storages.Remove(pair.Key) + + self.events_to_dispatch.append( + SmartContractEvent(SmartContractEvent.CONTRACT_DESTROY, ContractParameter(ContractParameterType.InteropInterface, contract), + hash, Blockchain.Default().Height + 1, + engine.ScriptContainer.Hash if engine.ScriptContainer else None, + test_mode=engine.testMode)) return True - def Iterator_Create(self, engine: ExecutionEngine): - item = engine.CurrentContext.EvaluationStack.Pop() - if isinstance(item, Map): - iterator = MapWrapper(item) - engine.CurrentContext.EvaluationStack.PushT(StackItem.FromInterface(iterator)) - return True + def Storage_PutEx(self, engine): + logger.error("Storage_PutEx not implemented!") return False - def Iterator_Key(self, engine: ExecutionEngine): - iterator = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if iterator is None: - return False + def Storage_Delete(self, engine: ExecutionEngine): - engine.CurrentContext.EvaluationStack.PushT(iterator.Key()) - return True + context = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - def Iterator_Keys(self, engine: ExecutionEngine): - iterator = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if iterator is None: + if not self.CheckStorageContext(context): return False - wrapper = StackItem.FromInterface(KeysWrapper(iterator)) - engine.CurrentContext.EvaluationStack.PushT(wrapper) - return True - def Iterator_Values(self, engine: ExecutionEngine): - iterator = engine.CurrentContext.EvaluationStack.Pop().GetInterface() - if iterator is None: - return False + key = engine.CurrentContext.EvaluationStack.Pop().GetByteArray() + + storage_key = StorageKey(script_hash=context.ScriptHash, key=key) + + keystr = key + if len(key) == 20: + keystr = Crypto.ToAddress(UInt160(data=key)) + + if type(engine) == ExecutionEngine: + test_mode = False + else: + test_mode = engine.testMode + self.events_to_dispatch.append(SmartContractEvent(SmartContractEvent.STORAGE_DELETE, ContractParameter(ContractParameterType.String, keystr), + context.ScriptHash, Blockchain.Default().Height + 1, + engine.ScriptContainer.Hash if engine.ScriptContainer else None, + test_mode=test_mode)) + + self._storages.Remove(storage_key.ToArray()) - wrapper = StackItem.FromInterface(ValuesWrapper(iterator)) - engine.CurrentContext.EvaluationStack.PushT(wrapper) return True diff --git a/neo/SmartContract/StorageContext.py b/neo/SmartContract/StorageContext.py index 89773bd63..b41d1972e 100644 --- a/neo/SmartContract/StorageContext.py +++ b/neo/SmartContract/StorageContext.py @@ -3,11 +3,7 @@ class StorageContext(InteropMixin): - ScriptHash = None - IsReadOnly = False - def __init__(self, script_hash, read_only=False): - self.ScriptHash = script_hash self.IsReadOnly = read_only diff --git a/neo/SmartContract/tests/test_app_engine.py b/neo/SmartContract/tests/test_app_engine.py index 2e82b0c9d..d0cd6e5be 100644 --- a/neo/SmartContract/tests/test_app_engine.py +++ b/neo/SmartContract/tests/test_app_engine.py @@ -4,24 +4,24 @@ from neo.VM.ExecutionContext import ExecutionContext from 
neo.SmartContract import TriggerType from mock import Mock, MagicMock +from neo.VM.Script import Script class TestApplicationEngine(NeoTestCase): def setUp(self): self.engine = ApplicationEngine(TriggerType.Application, Mock(), Mock(), Mock(), MagicMock()) - self.engine._Crypto = Mock() def test_get_item_count(self): - econtext1 = ExecutionContext(engine=self.engine) + econtext1 = ExecutionContext(Script(self.engine.Crypto, b''), 0) # 4 items in context 1 - map = Map({'a': 1, 'b': 2, 'c': 3}) + map = Map.FromDictionary({'a': 1, 'b': 2, 'c': 3}) my_int = Integer(BigInteger(1)) econtext1.EvaluationStack.PushT(map) econtext1.EvaluationStack.PushT(my_int) # 3 items in context 2 - econtext2 = ExecutionContext(engine=self.engine) + econtext2 = ExecutionContext(Script(self.engine.Crypto, b''), 0) my_array = Array([my_int, my_int]) econtext2.EvaluationStack.PushT(my_array) econtext2.AltStack.PushT(my_int) diff --git a/neo/SmartContract/tests/test_breakpoints.py b/neo/SmartContract/tests/test_breakpoints.py index 8d6f86915..ed4c04eaa 100644 --- a/neo/SmartContract/tests/test_breakpoints.py +++ b/neo/SmartContract/tests/test_breakpoints.py @@ -10,7 +10,6 @@ class UserWalletTestCase(WalletFixtureTestCase): - wallet_1_script_hash = UInt160(data=b'\x1c\xc9\xc0\\\xef\xff\xe6\xcd\xd7\xb1\x82\x81j\x91R\xec!\x8d.\xc0') wallet_1_addr = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3' @@ -38,14 +37,12 @@ def GetWallet1(cls, recreate=False): return cls._wallet1 def test_debug_contract_1(self): - wallet = self.GetWallet1() arguments = ["neo/SmartContract/tests/BreakpointTest.py", "True", "False", "True", "02", "01", "1", ] dbg = VMDebugger -# dbg.end = MagicMock(return_value=None) dbg.start = MagicMock(return_value=None) - tx, result, total_ops, engine = BuildAndRun(arguments, wallet, False, min_fee=Fixed8.FromDecimal(.0004)) + tx, result, total_ops, engine = BuildAndRun(arguments, wallet, False, min_fee=Fixed8.FromDecimal(.0004), enable_debugger=True) debugger = engine._vm_debugger context = debugger.get_context() @@ -62,9 +59,8 @@ def test_debug_contract_2(self): arguments = ["neo/SmartContract/tests/BreakpointTest.py", "True", "False", "True", "02", "01", "4", ] dbg = VMDebugger - # dbg.end = MagicMock(return_value=None) dbg.start = MagicMock(return_value=None) - tx, result, total_ops, engine = BuildAndRun(arguments, wallet, False, min_fee=Fixed8.FromDecimal(.0004)) + tx, result, total_ops, engine = BuildAndRun(arguments, wallet, False, min_fee=Fixed8.FromDecimal(.0004), enable_debugger=True) debugger = engine._vm_debugger context = debugger.get_context() diff --git a/neo/SmartContract/tests/test_payable.py b/neo/SmartContract/tests/test_payable.py index 9af523400..aefc2644d 100644 --- a/neo/SmartContract/tests/test_payable.py +++ b/neo/SmartContract/tests/test_payable.py @@ -6,7 +6,6 @@ class SmartContractPayable(WalletFixtureTestCase): - _wallet1 = None _path = "neo/SmartContract/tests/PayableTest.avm" @@ -21,7 +20,7 @@ def test_is_payable(self): Result [{'type': 'Boolean', 'value': True}] """ - wallet = self.GetWallet1() + wallet = self.GetWallet1(recreate=True) arguments = [self._path, "False", "False", "True", "07", "01", "payable"] diff --git a/neo/SmartContract/tests/test_vm_error_output.py b/neo/SmartContract/tests/test_vm_error_output.py index 5e26bd088..698f725c1 100644 --- a/neo/SmartContract/tests/test_vm_error_output.py +++ b/neo/SmartContract/tests/test_vm_error_output.py @@ -46,9 +46,12 @@ def test_invalid_type_indexing(self): def test_invalid_appcall(self): with self.assertLogHandler('vm', DEBUG) as 
log_context: tx, results, total_ops, engine = TestBuild(self.script, [4, ['my_arg0']], self.GetWallet1(), '0210', '07', dynamic=True) - self.assertTrue(len(log_context.output) > 1) - log_msg = log_context.output[1] - self.assertTrue("Trying to call an unknown contract" in log_msg) + found = False + for log_msg in log_context.output: + if "Trying to call an unknown contract" in log_msg: + found = True + break + self.assertTrue(found) def test_no_logging_if_loglevel_not_debug(self): with self.assertLogHandler('vm', INFO) as log_context: diff --git a/neo/Utils/BlockchainFixtureTestCase.py b/neo/Utils/BlockchainFixtureTestCase.py index c5b8b5d4a..558a888c6 100644 --- a/neo/Utils/BlockchainFixtureTestCase.py +++ b/neo/Utils/BlockchainFixtureTestCase.py @@ -3,13 +3,14 @@ import shutil import os import neo +import asyncio from neo.Utils.NeoTestCase import NeoTestCase from neo.Implementations.Blockchains.LevelDB.TestLevelDBBlockchain import TestLevelDBBlockchain from neo.Core.Blockchain import Blockchain from neo.Implementations.Notifications.LevelDB.NotificationDB import NotificationDB from neo.Settings import settings from neo.logging import log_manager -from neo.Network.NodeLeader import NodeLeader +from neo.Network.nodemanager import NodeManager logger = log_manager.getLogger() @@ -26,6 +27,9 @@ class BlockchainFixtureTestCase(NeoTestCase): wallets_folder = os.path.dirname(neo.__file__) + '/Utils/fixtures/' + def __init__(self, *args, **kwargs): + super(BlockchainFixtureTestCase, self).__init__(*args, **kwargs) + @classmethod def leveldb_testpath(cls): return 'Override Me!' @@ -37,8 +41,14 @@ def setUpClass(cls): super(BlockchainFixtureTestCase, cls).setUpClass() - NodeLeader.Instance().Reset() - NodeLeader.Instance().Setup() + # for some reason during testing asyncio.get_event_loop() fails and does not create a new one if needed. 
This is the workaround + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + nodemgr = NodeManager() + nodemgr.reset_for_test() # setup Blockchain DB if not os.path.exists(cls.FIXTURE_FILENAME): diff --git a/neo/Utils/NeoTestCase.py b/neo/Utils/NeoTestCase.py index 42dcbdf5d..c711f80fe 100644 --- a/neo/Utils/NeoTestCase.py +++ b/neo/Utils/NeoTestCase.py @@ -2,7 +2,9 @@ from unittest.case import _BaseTestCaseContext import logging import collections +import asyncio from neo.logging import log_manager +from mock import MagicMock class _CapturingHandler(logging.Handler): @@ -57,7 +59,16 @@ def __exit__(self, exc_type, exc_value, tb): self._logger.handlers[0] = self.stdio_handler +class AsyncMock(MagicMock): + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + + class NeoTestCase(TestCase): + + def __init__(self, *args, **kwargs): + super(NeoTestCase, self).__init__(*args, **kwargs) + def assertLogHandler(self, component_name: str, level: int): """ This method must be used as a context manager, and will yield @@ -75,3 +86,11 @@ def assertLogHandler(self, component_name: str, level: int): context manager """ return _AssertLogHandlerContext(self, component_name, level) + + def async_return(self, result): + f = asyncio.Future() + f.set_result(result) + return f + + def new_async_mock(self): + return AsyncMock() diff --git a/neo/Utils/VMJSONTestCase.py b/neo/Utils/VMJSONTestCase.py index bf32b8d27..d3a2247c6 100644 --- a/neo/Utils/VMJSONTestCase.py +++ b/neo/Utils/VMJSONTestCase.py @@ -10,7 +10,7 @@ class VMJSONTestCase(NeoTestCase): - NEO_VM_REPO_URL = "https://github.com/neo-project/neo-vm/tarball/c45330eee5a0ef47a03a7dad212318a7acaf01b5" + NEO_VM_REPO_URL = "https://github.com/neo-project/neo-vm/tarball/cd5c3d0460bd1d4acce34be91c38a2ccfca8050f" SOURCE_FILENAME = os.path.join(settings.DATA_DIR_PATH, 'vm-tests/neo-vm.tar.gz') @classmethod diff --git a/neo/VM/Debugger.py b/neo/VM/Debugger.py new file mode 100644 index 000000000..2240ac4ab --- /dev/null +++ b/neo/VM/Debugger.py @@ -0,0 +1,112 @@ +import datetime +from neo.VM.ExecutionEngine import ExecutionEngine +from neo.VM import VMState +from neo.Settings import settings +from neo.Prompt.vm_debugger import VMDebugger +from neo.logging import log_manager + +logger = log_manager.getLogger('vm') + + +class Debugger: + def __init__(self, engine: ExecutionEngine): + self.engine = engine + self.engine.debugger = self + self._breakpoints = dict() + + def Execute(self): + self.engine._VMState &= ~VMState.BREAK + + def loop_execute_next(): + while self.engine._VMState & VMState.HALT == 0 \ + and self.engine._VMState & VMState.FAULT == 0 \ + and self.engine._VMState & VMState.BREAK == 0: + self.ExecuteAndCheckBreakpoint() + + if settings.log_vm_instructions: + with open(self.engine.log_file_name, 'w') as self.log_file: + self.engine.write_log(str(datetime.datetime.now())) + loop_execute_next() + else: + loop_execute_next() + + return not self.engine._VMState & VMState.FAULT > 0 + + def ExecuteAndCheckBreakpoint(self): + self.engine.ExecuteNext() + + if self.engine._VMState == VMState.NONE and self.engine._InvocationStack.Count > 0: + script_hash = self.engine.CurrentContext.ScriptHash() + bps = self._breakpoints.get(script_hash, None) + if bps is not None: + if self.engine.CurrentContext.InstructionPointer in bps: + self.engine._VMState = VMState.BREAK + self.engine._vm_debugger = VMDebugger(self.engine) + 
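The ``AsyncMock`` and ``async_return`` helpers added to ``NeoTestCase`` above exist because ``unittest.mock`` has no awaitable mock before Python 3.8. A minimal sketch of how a test might use them on Python 3.7 (the ``AsyncHelpersExample`` class and ``fake_send`` mock are illustrative only, not taken from the codebase):

.. code-block:: python

    import asyncio
    import unittest
    from neo.Utils.NeoTestCase import NeoTestCase, AsyncMock


    class AsyncHelpersExample(NeoTestCase):
        def test_awaitable_mocks(self):
            loop = asyncio.get_event_loop()

            # AsyncMock behaves like MagicMock but returns a coroutine when called,
            # so code under test can simply `await` it; the call is recorded as usual.
            fake_send = AsyncMock(return_value=True)
            self.assertTrue(loop.run_until_complete(fake_send("ping")))
            fake_send.assert_called_once_with("ping")

            # async_return wraps a plain value in an already-resolved Future, handy
            # when patching an attribute whose callers only ever await it.
            self.assertEqual(loop.run_until_complete(self.async_return(42)), 42)


    if __name__ == '__main__':
        unittest.main()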
self.engine._vm_debugger.start() + + def AddBreakPoint(self, script_hash, position): + ctx_breakpoints = self._breakpoints.get(script_hash, None) + if ctx_breakpoints is None: + self._breakpoints[script_hash] = set([position]) + else: + # add by reference + ctx_breakpoints.add(position) + + def RemoveBreakPoint(self, script_hash, position): + # test if any breakpoints exist for script hash + ctx = self._breakpoints.get(script_hash, None) + if ctx is None: + return False + + # remove if specific bp exists + if position in ctx: + ctx.remove(position) + else: + return False + + # clear set from breakpoints list if no more bp's exist for it + if len(ctx) == 0: + del self._breakpoints[script_hash] + + return True + + def StepInto(self): + + if self.engine._VMState & VMState.HALT > 0 or self.engine._VMState & VMState.FAULT > 0: + logger.debug("stopping because vm state is %s " % self.engine._VMState) + return + self.engine.ExecuteNext() + if self.engine._VMState == VMState.NONE: + self.engine._VMState = VMState.BREAK + + def StepOut(self): + self.engine._VMState &= ~VMState.BREAK + c = self.engine.InvocationStack.Count + + while self.engine._VMState & VMState.HALT == 0 \ + and self.engine._VMState & VMState.FAULT == 0 \ + and self.engine._VMState & VMState.BREAK == 0 \ + and self.engine.InvocationStack.Count >= c: + self.ExecuteAndCheckBreakpoint() + + if self.engine._VMState == VMState.NONE: + self.engine._VMState = VMState.BREAK + + def StepOver(self): + if self.engine._VMState & VMState.HALT > 0 or self.engine._VMState & VMState.FAULT > 0: + return + + self.engine._VMState &= ~VMState.BREAK + c = self.engine.InvocationStack.Count + while True: + self.ExecuteAndCheckBreakpoint() + + go_on = self.engine._VMState & VMState.HALT == 0 \ + and self.engine._VMState & VMState.FAULT == 0 \ + and self.engine._VMState & VMState.BREAK == 0 \ + and self.engine.InvocationStack.Count > c # noqa + if not go_on: + break + + if self.engine._VMState == VMState.NONE: + self.engine._VMState = VMState.BREAK diff --git a/neo/VM/ExecutionContext.py b/neo/VM/ExecutionContext.py index cebee8e33..2ad8a730d 100644 --- a/neo/VM/ExecutionContext.py +++ b/neo/VM/ExecutionContext.py @@ -2,19 +2,23 @@ from neo.Core.IO.BinaryReader import BinaryReader from neo.VM.RandomAccessStack import RandomAccessStack from neo.VM.OpCode import RET +from neo.VM.Instruction import Instruction +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from neo.VM.Script import Script -class ExecutionContext: - Script = None - - __OpReader = None - - __mstream = None - _RVCount = None +class ExecutionContext: - _EvaluationStack = None - _AltStack = None + def __init__(self, script: 'Script', rvcount: int): + self.instructions = {} + self._EvaluationStack = RandomAccessStack(name='Evaluation') + self._AltStack = RandomAccessStack(name='Alt') + self.InstructionPointer = 0 + self.Script = script + self._RVCount = rvcount + self._script_hash = None @property def EvaluationStack(self): @@ -25,43 +29,30 @@ def AltStack(self): return self._AltStack @property - def OpReader(self): - return self.__OpReader + def CurrentInstruction(self): + return self.GetInstruction(self.InstructionPointer) @property - def InstructionPointer(self): - return self.__OpReader.stream.tell() - - @InstructionPointer.setter - def InstructionPointer(self, value): - self.__OpReader.stream.seek(value) + def NextInstruction(self): + return self.GetInstruction(self.InstructionPointer + self.CurrentInstruction.Size) - def SetInstructionPointer(self, value): - 
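The ``Debugger`` introduced above takes over the stepping and breakpoint bookkeeping that used to sit directly on ``ExecutionEngine``. A rough driver sketch, not part of the patch, assuming the ``OpCode`` constants are the usual single-byte values that can be concatenated into a tiny placeholder script:

.. code-block:: python

    from neo.Core.Cryptography.Crypto import Crypto
    from neo.VM.ExecutionEngine import ExecutionEngine
    from neo.VM.Debugger import Debugger
    from neo.VM.OpCode import PUSH1, PUSH2, ADD, RET

    engine = ExecutionEngine(crypto=Crypto.Default())
    context = engine.LoadScript(bytearray(PUSH1 + PUSH2 + ADD + RET))

    # The Debugger registers itself on the engine and owns the breakpoint table
    # that previously lived on ExecutionEngine.
    debugger = Debugger(engine)
    debugger.AddBreakPoint(context.ScriptHash(), 2)     # IP of the ADD instruction

    debugger.StepInto()                                 # PUSH1
    debugger.StepInto()                                 # PUSH2
    debugger.RemoveBreakPoint(context.ScriptHash(), 2)

    # Runs to HALT/FAULT; a breakpoint hit here would instead drop into the
    # interactive VMDebugger prompt from neo.Prompt.vm_debugger.
    debugger.Execute()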
self.__OpReader.stream.seek(value) + def ScriptHash(self): + return self.Script.ScriptHash - @property - def NextInstruction(self): - index = self.__OpReader.stream.tell() - if index >= len(self.Script): - return RET - else: - return self.Script[index].to_bytes(1, 'little') + def GetInstruction(self, ip) -> Instruction: + if ip >= self.Script.Length: + return Instruction.RET() + instruction = self.instructions.get(ip, None) - _script_hash = None + if instruction is None: + instruction = Instruction.FromScriptAndIP(self.Script, ip) + self.instructions.update({ip: instruction}) - def ScriptHash(self): - if self._script_hash is None: - self._script_hash = self.crypto.Hash160(self.Script) - return self._script_hash + return instruction - def __init__(self, engine=None, script=None, rvcount=0): - self.Script = script - self.__mstream = StreamManager.GetStream(self.Script) - self.__OpReader = BinaryReader(self.__mstream) - self._EvaluationStack = RandomAccessStack(name='Evaluation') - self._AltStack = RandomAccessStack(name='Alt') - self._RVCount = rvcount - self.crypto = engine.Crypto + def MoveNext(self): + self.InstructionPointer += self.CurrentInstruction.Size + return self.InstructionPointer < self.Script.Length def Dispose(self): self.__OpReader = None diff --git a/neo/VM/ExecutionEngine.py b/neo/VM/ExecutionEngine.py index ce83c6566..2637fea3f 100644 --- a/neo/VM/ExecutionEngine.py +++ b/neo/VM/ExecutionEngine.py @@ -1,5 +1,6 @@ import hashlib import datetime +import traceback from neo.VM.OpCode import * from neo.VM.RandomAccessStack import RandomAccessStack @@ -9,11 +10,12 @@ from neo.Core.UInt160 import UInt160 from neo.Settings import settings from neo.VM.VMFault import VMFault -from neo.Prompt.vm_debugger import VMDebugger from logging import DEBUG as LOGGING_LEVEL_DEBUG from neo.logging import log_manager from typing import TYPE_CHECKING from collections import deque +from neo.VM.OpCode import ToName +from neo.VM.Script import Script if TYPE_CHECKING: from neo.VM.InteropService import BigInteger @@ -37,7 +39,7 @@ class ExecutionEngine: maxStackSize = 2048 maxInvocationStackSize = 1024 - def __init__(self, container=None, crypto=None, table=None, service=None, exit_on_error=False): + def __init__(self, container=None, crypto=None, table=None, service=None, exit_on_error=True): self._VMState = VMState.BREAK self._ScriptContainer = container self._Crypto = crypto @@ -50,9 +52,9 @@ def __init__(self, container=None, crypto=None, table=None, service=None, exit_o self.ops_processed = 0 self._debug_map = None self._is_write_log = settings.log_vm_instructions - self._breakpoints = dict() self._is_stackitem_count_strict = True self._stackitem_count = 0 + self._EntryScriptHash = None def CheckArraySize(self, length: int) -> bool: return length <= self.maxArraySize @@ -99,11 +101,11 @@ def GetItemCount(self, items_list): # list of StackItems items = deque(items_list) while items: stackitem = items.pop() - if isinstance(stackitem, Map): + if stackitem.IsTypeMap: items.extend(stackitem.Values) continue - if isinstance(stackitem, Array): + if stackitem.IsTypeArray: items.extend(stackitem.GetArray()) continue count += 1 @@ -150,10 +152,6 @@ def CallingContext(self): return self.InvocationStack.Peek(1) return None - @property - def ExitOnError(self): - return self._exit_on_error - @property def EntryContext(self): return self.InvocationStack.Peek(self.InvocationStack.Count - 1) @@ -162,22 +160,13 @@ def EntryContext(self): def ExecutedScriptHashes(self): return self._ExecutedScriptHashes - def 
AddBreakPoint(self, script_hash, position): - ctx_breakpoints = self._breakpoints.get(script_hash, None) - if ctx_breakpoints is None: - self._breakpoints[script_hash] = set([position]) - else: - # add by reference - ctx_breakpoints.add(position) - def LoadDebugInfoForScriptHash(self, debug_map, script_hash): if debug_map and script_hash: self._debug_map = debug_map self._debug_map['script_hash'] = script_hash def Dispose(self): - while self._InvocationStack.Count > 0: - self._InvocationStack.Pop().Dispose() + self.InvocationStack.Clear() def Execute(self): self._VMState &= ~VMState.BREAK @@ -193,55 +182,37 @@ def loop_stepinto(): else: loop_stepinto() - def ExecuteOp(self, opcode, context: ExecutionContext): + return not self._VMState & VMState.FAULT > 0 + + def ExecuteInstruction(self): + context = self.CurrentContext + instruction = context.CurrentInstruction + opcode = instruction.OpCode estack = context._EvaluationStack istack = self._InvocationStack astack = context._AltStack - if opcode >= PUSHBYTES1 and opcode <= PUSHBYTES75: - bytestoread = context.OpReader.SafeReadBytes(int.from_bytes(opcode, 'little')) - estack.PushT(bytestoread) + if opcode >= PUSHBYTES1 and opcode <= PUSHDATA4: + if not self.CheckMaxItemSize(len(instruction.Operand)): + return False + estack.PushT(instruction.Operand) if not self.CheckStackSize(True): return self.VM_FAULT_and_report(VMFault.INVALID_STACKSIZE) else: # push values - pushops = [PUSHM1, PUSH1, PUSH2, PUSH3, PUSH4, PUSH5, PUSH6, PUSH7, PUSH8, - PUSH9, PUSH10, PUSH11, PUSH12, PUSH13, PUSH14, PUSH15, PUSH16] - - if opcode == PUSH0: - estack.PushT(bytearray(0)) - if not self.CheckStackSize(True): - return self.VM_FAULT_and_report(VMFault.INVALID_STACKSIZE) - - elif opcode == PUSHDATA1: - lenngth = ord(context.OpReader.ReadByte()) - estack.PushT(bytearray(context.OpReader.SafeReadBytes(lenngth))) - - if not self.CheckStackSize(True): - return self.VM_FAULT_and_report(VMFault.INVALID_STACKSIZE) - - elif opcode == PUSHDATA2: - estack.PushT(context.OpReader.SafeReadBytes(context.OpReader.ReadUInt16())) + if opcode in [PUSHM1, PUSH1, PUSH2, PUSH3, PUSH4, PUSH5, PUSH6, PUSH7, PUSH8, + PUSH9, PUSH10, PUSH11, PUSH12, PUSH13, PUSH14, PUSH15, PUSH16]: - if not self.CheckStackSize(True): - return self.VM_FAULT_and_report(VMFault.INVALID_STACKSIZE) - - elif opcode == PUSHDATA4: - length = context.OpReader.ReadUInt32() - if not self.CheckMaxItemSize(length): - return self.VM_FAULT_and_report(VMFault.PUSHDATA_EXCEED_MAXITEMSIZE) - - estack.PushT(context.OpReader.SafeReadBytes(length)) + topush = int.from_bytes(opcode, 'little') - int.from_bytes(PUSH1, 'little') + 1 + estack.PushT(topush) if not self.CheckStackSize(True): return self.VM_FAULT_and_report(VMFault.INVALID_STACKSIZE) - elif opcode in pushops: - topush = int.from_bytes(opcode, 'little') - int.from_bytes(PUSH1, 'little') + 1 - estack.PushT(topush) - + elif opcode == PUSH0: + estack.PushT(bytearray(0)) if not self.CheckStackSize(True): return self.VM_FAULT_and_report(VMFault.INVALID_STACKSIZE) @@ -249,10 +220,9 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == NOP: pass elif opcode in [JMP, JMPIF, JMPIFNOT]: - offset_b = context.OpReader.ReadInt16() - offset = context.InstructionPointer + offset_b - 3 + offset = context.InstructionPointer + instruction.TokenI16 - if offset < 0 or offset > len(context.Script): + if offset < 0 or offset > context.Script.Length: return self.VM_FAULT_and_report(VMFault.INVALID_JUMP) fValue = True @@ -262,19 +232,21 @@ def ExecuteOp(self, opcode, 
context: ExecutionContext): if opcode == JMPIFNOT: fValue = not fValue if fValue: - context.SetInstructionPointer(offset) + context.InstructionPointer = offset + else: + context.InstructionPointer += 3 + return True elif opcode == CALL: if not self.CheckMaxInvocationStack(): return self.VM_FAULT_and_report(VMFault.CALL_EXCEED_MAX_INVOCATIONSTACK_SIZE) - context_call = self.LoadScript(context.Script) + context_call = self._LoadScriptInternal(context.Script) + context_call.InstructionPointer = context.InstructionPointer + instruction.TokenI16 + if context_call.InstructionPointer < 0 or context_call.InstructionPointer > context_call.Script.Length: + return False context.EvaluationStack.CopyTo(context_call.EvaluationStack) - context_call.InstructionPointer = context.InstructionPointer context.EvaluationStack.Clear() - context.InstructionPointer += 2 - - self.ExecuteOp(JMP, context_call) elif opcode == RET: context_pop: ExecutionContext = istack.Pop() @@ -285,7 +257,6 @@ def ExecuteOp(self, opcode, context: ExecutionContext): if rvcount > 0: if context_pop.EvaluationStack.Count < rvcount: - context_pop.Dispose() return self.VM_FAULT_and_report(VMFault.UNKNOWN1) if istack.Count == 0: @@ -301,7 +272,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): if istack.Count == 0: self._VMState = VMState.HALT - context_pop.Dispose() + return True elif opcode == APPCALL or opcode == TAILCALL: if self._Table is None: @@ -310,7 +281,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): if opcode == APPCALL and not self.CheckMaxInvocationStack(): return self.VM_FAULT_and_report(VMFault.APPCALL_EXCEED_MAX_INVOCATIONSTACK_SIZE) - script_hash = context.OpReader.SafeReadBytes(20) + script_hash = instruction.Operand is_normal_call = False for b in script_hash: @@ -320,31 +291,27 @@ def ExecuteOp(self, opcode, context: ExecutionContext): if not is_normal_call: script_hash = estack.Pop().GetByteArray() - script = self._Table.GetScript(UInt160(data=script_hash).ToBytes()) - - if script is None: + context_new = self._LoadScriptByHash(script_hash) + if context_new is None: return self.VM_FAULT_and_report(VMFault.INVALID_CONTRACT, script_hash) - context_new = self.LoadScript(script) estack.CopyTo(context_new.EvaluationStack) if opcode == TAILCALL: - istack.Remove(1).Dispose() + istack.Remove(1) else: estack.Clear() self.CheckStackSize(False, 0) elif opcode == SYSCALL: - try: - call = context.OpReader.ReadVarBytes(252).decode('ascii') - except ValueError: - # probably failed to read enough bytes - return self.VM_FAULT_and_report(VMFault.SYSCALL_INSUFFICIENT_DATA) + if len(instruction.Operand) > 252: + return False + call = instruction.Operand.decode('ascii') self.write_log(call) if not self._Service.Invoke(call, self): - return self.VM_FAULT_and_report(VMFault.SYSCALL_ERROR, call) + return self.VM_FAULT_and_report(VMFault.SYSCALL_ERROR, instruction.Operand) if not self.CheckStackSize(False, int_MaxValue): return self.VM_FAULT_and_report(VMFault.INVALID_STACKSIZE) @@ -465,13 +432,6 @@ def ExecuteOp(self, opcode, context: ExecutionContext): x = estack.Pop().GetByteArray() - len_x = len(x) - if index > len_x: - return self.VM_FAULT_and_report(VMFault.SUBSTR_INVALID_INDEX) - - if index + count > len_x: - count = len_x - index - estack.PushT(x[index:count + index]) self.CheckStackSize(True, -2) @@ -507,19 +467,12 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == INVERT: x = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x): - return 
self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(~x) elif opcode == AND: x2 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x2): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) - x1 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x1): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(x1 & x2) self.CheckStackSize(True, -1) @@ -527,12 +480,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == OR: x2 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x2): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) - x1 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x1): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(x1 | x2) self.CheckStackSize(True, -1) @@ -540,12 +488,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == XOR: x2 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x2): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) - x1 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x1): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(x1 ^ x2) self.CheckStackSize(True, -1) @@ -576,24 +519,18 @@ def ExecuteOp(self, opcode, context: ExecutionContext): # Make sure to implement sign for big integer x = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(x.Sign) elif opcode == NEGATE: x = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(-x) elif opcode == ABS: x = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(abs(x)) @@ -606,8 +543,6 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == NZ: x = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(x is not 0) @@ -742,12 +677,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == LT: x2 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x2): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) - x1 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x1): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(x1 < x2) self.CheckStackSize(True, -1) @@ -755,12 +685,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == GT: x2 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x2): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) - x1 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x1): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(x1 > x2) self.CheckStackSize(True, -1) @@ -768,12 +693,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == LTE: x2 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x2): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) - x1 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x1): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(x1 <= x2) self.CheckStackSize(True, -1) @@ -789,12 +709,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == MIN: x2 = estack.Pop().GetBigInteger() - if 
not self.CheckBigInteger(x2): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) - x1 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x1): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(min(x1, x2)) self.CheckStackSize(True, -1) @@ -802,12 +717,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == MAX: x2 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x2): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) - x1 = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x1): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(max(x1, x2)) self.CheckStackSize(True, -1) @@ -815,16 +725,8 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == WITHIN: b = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(b): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) - a = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(a): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) - x = estack.Pop().GetBigInteger() - if not self.CheckBigInteger(x): - return self.VM_FAULT_and_report(VMFault.BIGINTEGER_EXCEED_LIMIT) estack.PushT(a <= x and x < b) self.CheckStackSize(True, -2) @@ -1007,7 +909,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): to_pick = items[index] estack.PushT(to_pick) - if not self.CheckStackSize(False, int_MaxValue): + if not self.CheckStackSize(False, -1): self.VM_FAULT_and_report(VMFault.INVALID_STACKSIZE) elif isinstance(collection, Map): @@ -1016,7 +918,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): if success: estack.PushT(value) - if not self.CheckStackSize(False, int_MaxValue): + if not self.CheckStackSize(False, -1): self.VM_FAULT_and_report(VMFault.INVALID_STACKSIZE) else: @@ -1028,7 +930,7 @@ def ExecuteOp(self, opcode, context: ExecutionContext): self.VM_FAULT_and_report(VMFault.PICKITEM_INVALID_INDEX, index, len(byte_array)) return estack.PushT(byte_array[index]) - self.CheckStackSize(False, -1) + self.CheckStackSize(True, -1) elif opcode == SETITEM: value = estack.Pop() @@ -1227,27 +1129,29 @@ def ExecuteOp(self, opcode, context: ExecutionContext): elif opcode == CALL_I: if not self.CheckMaxInvocationStack(): return self.VM_FAULT_and_report(VMFault.CALL__I_EXCEED_MAX_INVOCATIONSTACK_SIZE) - rvcount = ord(context.OpReader.ReadByte()) - pcount = ord(context.OpReader.ReadByte()) + rvcount = instruction.Operand[0] + pcount = instruction.Operand[1] if estack.Count < pcount: return self.VM_FAULT_and_report(VMFault.UNKNOWN_STACKISOLATION) - context_call = self.LoadScript(context.Script, rvcount) + context_call = self._LoadScriptInternal(context.Script, rvcount) + context_call.InstructionPointer = context.InstructionPointer + instruction.TokenI16_1 + 2 + + if context_call.InstructionPointer < 0 or context_call.InstructionPointer > context_call.Script.Length: + return False + estack.CopyTo(context_call.EvaluationStack, pcount) - context_call.InstructionPointer = context.InstructionPointer for i in range(0, pcount, 1): estack.Pop() - context.InstructionPointer += 2 - self.ExecuteOp(JMP, context_call) elif opcode in [CALL_E, CALL_ED, CALL_ET, CALL_EDT]: if self._Table is None: return self.VM_FAULT_and_report(VMFault.UNKNOWN_STACKISOLATION2) - rvcount = ord(context.OpReader.ReadByte()) - pcount = ord(context.OpReader.ReadByte()) + rvcount = instruction.Operand[0] + pcount = instruction.Operand[1] if estack.Count < pcount: return 
self.VM_FAULT_and_report(VMFault.UNKNOWN_STACKISOLATION) @@ -1263,19 +1167,16 @@ def ExecuteOp(self, opcode, context: ExecutionContext): script_hash = estack.Pop().GetByteArray() self.CheckStackSize(True, -1) else: - script_hash = context.OpReader.SafeReadBytes(20) - - script = self._Table.GetScript(UInt160(data=script_hash).ToBytes()) + script_hash = instruction.ReadBytes(2, 20) - if script is None: - logger.debug("Could not find script from script table: %s " % script_hash) + context_new = self._LoadScriptByHash(script_hash, rvcount) + if context_new is None: return self.VM_FAULT_and_report(VMFault.INVALID_CONTRACT, script_hash) - context_new = self.LoadScript(script, rvcount) estack.CopyTo(context_new.EvaluationStack, pcount) if opcode in [CALL_ET, CALL_EDT]: - istack.Remove(1).Dispose() + istack.Remove(1) else: for i in range(0, pcount, 1): estack.Pop() @@ -1290,90 +1191,76 @@ def ExecuteOp(self, opcode, context: ExecutionContext): else: return self.VM_FAULT_and_report(VMFault.UNKNOWN_OPCODE, opcode) + context.MoveNext() + return True + + def LoadScript(self, script: bytearray, rvcount: int = -1) -> ExecutionContext: + # "raw" bytes + new_script = Script(self.Crypto, script) - def LoadScript(self, script, rvcount=-1) -> ExecutionContext: + return self._LoadScriptInternal(new_script, rvcount) - context = ExecutionContext(self, script, rvcount) + def _LoadScriptInternal(self, script: Script, rvcount=-1): + context = ExecutionContext(script, rvcount) self._InvocationStack.PushT(context) self._ExecutedScriptHashes.append(context.ScriptHash()) # add break points for current script if available script_hash = context.ScriptHash() if self._debug_map and script_hash == self._debug_map['script_hash']: - self._breakpoints[script_hash] = set(self._debug_map['breakpoints']) + if self.debugger: + self.debugger._breakpoints[script_hash] = set(self._debug_map['breakpoints']) return context - def RemoveBreakPoint(self, script_hash, position): - # test if any breakpoints exist for script hash - ctx = self._breakpoints.get(script_hash, None) - if ctx is None: - return False + def _LoadScriptByHash(self, script_hash: bytearray, rvcount=-1): - # remove if specific bp exists - if position in ctx: - ctx.remove(position) - else: - return False + if self._Table is None: + return None + script = self._Table.GetScript(UInt160(data=script_hash).ToBytes()) + if script is None: + return None + return self._LoadScriptInternal(Script.FromHash(script_hash, script), rvcount) - # clear set from breakpoints list if no more bp's exist for it - if len(ctx) == 0: - del self._breakpoints[script_hash] + def PreExecuteInstruction(self): + # allow overriding + return True + def PostExecuteInstruction(self): + # allow overriding return True def ExecuteNext(self): if self._InvocationStack.Count == 0: self._VMState = VMState.HALT else: - op = None - - if self.CurrentContext.InstructionPointer >= len(self.CurrentContext.Script): - op = RET - else: - op = self.CurrentContext.OpReader.ReadByte() - self.ops_processed += 1 try: + instruction = self.CurrentContext.CurrentInstruction if self._is_write_log: - self.write_log("{} {}".format(self.ops_processed, ToName(op))) - self.ExecuteOp(op, self.CurrentContext) + self.write_log("{} {} {}".format(self.ops_processed, instruction.InstructionName, self.CurrentContext.InstructionPointer)) + + if not self.PreExecuteInstruction(): + self._VMState = VMState.FAULT + if not self.ExecuteInstruction(): + self._VMState = VMState.FAULT + if not self.PostExecuteInstruction(): + self._VMState = 
VMState.FAULT except Exception as e: - error_msg = "COULD NOT EXECUTE OP (%s): %s %s %s" % (self.ops_processed, e, op, ToName(op)) + + error_msg = f"COULD NOT EXECUTE OP ({self.ops_processed}): {e}" + # traceback.print_exc() self.write_log(error_msg) if self._exit_on_error: self._VMState = VMState.FAULT - # This is how C# does it now (2019-03-25) according to - # https://github.com/neo-project/neo-vm/blob/0e5cb856021e288915623fe994084d97a0417373/src/neo-vm/ExecutionEngine.cs#L234 - # but that doesn't help us set a breakpoint, so we use the old logic - # TODO: revisit at some point - # if self._VMState == VMState.NONE and self._InvocationStack.Count > 0: - if self._VMState & VMState.FAULT == 0 and self._InvocationStack.Count > 0: - script_hash = self.CurrentContext.ScriptHash() - bps = self._breakpoints.get(self.CurrentContext.ScriptHash(), None) - if bps is not None: - if self.CurrentContext.InstructionPointer in bps: - self._VMState = VMState.BREAK - self._vm_debugger = VMDebugger(self) - self._vm_debugger.start() - - def StepInto(self): - - if self._VMState & VMState.HALT > 0 or self._VMState & VMState.FAULT > 0: - logger.debug("stopping because vm state is %s " % self._VMState) - return - self.ExecuteNext() - if self._VMState == VMState.NONE: - self._VMState = VMState.BREAK - def VM_FAULT_and_report(self, id, *args): self._VMState = VMState.FAULT if not logger.hasHandlers() or logger.handlers[0].level != LOGGING_LEVEL_DEBUG: - return + return False # if settings.log_level != LOGGING_LEVEL_DEBUG: # return @@ -1451,4 +1338,4 @@ def VM_FAULT_and_report(self, id, *args): else: logger.debug("({}) {}".format(self.ops_processed, error_msg)) - return + return False diff --git a/neo/VM/Instruction.py b/neo/VM/Instruction.py new file mode 100644 index 000000000..a6cd21315 --- /dev/null +++ b/neo/VM/Instruction.py @@ -0,0 +1,129 @@ +from neo.VM import OpCode +from typing import TYPE_CHECKING +from functools import lru_cache +import binascii + +if TYPE_CHECKING: + from neo.VM.Script import Script + +_OperandSizeTable = {} +start = int.from_bytes(OpCode.PUSHBYTES1, 'little') +end = int.from_bytes(OpCode.PUSHBYTES75, 'little') + 1 +for op_num in range(start, end): + _OperandSizeTable[int.to_bytes(op_num, 1, 'little')] = op_num +_OperandSizeTable[OpCode.JMP] = 2 +_OperandSizeTable[OpCode.JMPIF] = 2 +_OperandSizeTable[OpCode.JMPIFNOT] = 2 +_OperandSizeTable[OpCode.CALL] = 2 +_OperandSizeTable[OpCode.APPCALL] = 20 +_OperandSizeTable[OpCode.TAILCALL] = 20 +_OperandSizeTable[OpCode.CALL_I] = 4 +_OperandSizeTable[OpCode.CALL_E] = 22 +_OperandSizeTable[OpCode.CALL_ED] = 2 +_OperandSizeTable[OpCode.CALL_ET] = 22 +_OperandSizeTable[OpCode.CALL_EDT] = 2 + + +class Instruction: + + @classmethod + def RET(cls): + return cls(0x66) + + def __init__(self, opcode: int): + self.OpCode = int.to_bytes(opcode, 1, 'little') + self.Operand = bytearray() + self._OperandSizeTable = _OperandSizeTable + self._OperandSizePrefixTable = {} + + @property + @lru_cache() + def OperandSizePrefixTable(self): + + self._OperandSizePrefixTable[OpCode.PUSHDATA1] = 1 + self._OperandSizePrefixTable[OpCode.PUSHDATA2] = 2 + self._OperandSizePrefixTable[OpCode.PUSHDATA4] = 4 + self._OperandSizePrefixTable[OpCode.SYSCALL] = 1 + return self._OperandSizePrefixTable + + @property + @lru_cache() + def OperandSizeTable(self): + return self._OperandSizeTable + + @classmethod + def FromScriptAndIP(clss, script: 'Script', ip: int): + ins = clss(script[ip]) + ip += 1 + operand_size = ins.OperandSizePrefixTable.get(ins.OpCode, 0) + + if operand_size == 
0: + operand_size = ins.OperandSizeTable.get(ins.OpCode, 0) + elif operand_size == 1: + ip, operand_size = ins.ReadByte(script, ip) + elif operand_size == 2: + ip, operand_size = ins.ReadUint16(script, ip) + elif operand_size == 4: + ip, operand_size = ins.ReadInt32(script, ip) + + if (operand_size > 0): + ins.Operand = ins.ReadExactBytes(script, ip, operand_size) + return ins + + @property + def InstructionName(self): + return OpCode.ToName(self.OpCode) + + @property + def Size(self): + prefixSize = self.OperandSizePrefixTable.get(self.OpCode, 0) + + if prefixSize > 0: + return 1 + prefixSize + len(self.Operand) + else: + return 1 + self.OperandSizeTable.get(self.OpCode, 0) + + @property + def TokenI16(self): + return int.from_bytes(self.Operand, 'little', signed=True) + + @property + def TokenI16_1(self): + return int.from_bytes(self.Operand[2:], 'little', signed=True) + + @property + def TokenU32(self): + return int.from_bytes(self.Operand, 'little', signed=False) + + @property + def TokenString(self): + return self.Operand.decode('ascii') + + def ReadByte(self, script, ip): + next_byte_index = ip + 1 + if next_byte_index > script.Length: + raise ValueError + return next_byte_index, script[ip] + + def ReadUint16(self, script, ip): + if ip + 2 > script.Length: + raise ValueError + return ip + 2, int.from_bytes(script[ip:ip + 2], 'little', signed=False) + + def ReadInt32(self, script, ip): + if ip + 4 > script.Length: + raise ValueError + return ip + 4, int.from_bytes(script[ip:ip + 4], 'little', signed=True) + + def ReadBytes(self, offset, count): + if offset + count > len(self.Operand): + raise Exception + return self.Operand[offset:offset + count] + + def ReadExactBytes(self, script, ip, count): + if ip + count > script.Length: + raise ValueError + return script[ip:ip + count] + + def __str__(self): + return self.InstructionName diff --git a/neo/VM/InteropService.py b/neo/VM/InteropService.py index 91004a212..64a9e664c 100644 --- a/neo/VM/InteropService.py +++ b/neo/VM/InteropService.py @@ -10,8 +10,10 @@ class CollectionMixin: - IsSynchronized = False - SyncRoot = None + + def __init__(self): + self.IsSynchronized = False + self.SyncRoot = None @property def Count(self): @@ -29,6 +31,14 @@ def CopyTo(self, array, index): class StackItem(EquatableMixin): + @property + def IsTypeMap(self): + return False + + @property + def IsTypeArray(self): + return False + @property def IsStruct(self): return False @@ -143,7 +153,10 @@ def New(value): class Array(StackItem, CollectionMixin): - _array = None # a list of stack items + + @property + def IsTypeArray(self): + return True @property def Count(self): @@ -229,8 +242,6 @@ class Boolean(StackItem): TRUE = bytearray([1]) FALSE = bytearray() # restore once https://github.com/neo-project/neo-vm/pull/132 is approved - _value = None - def __init__(self, value): self._value = value @@ -258,6 +269,9 @@ def GetBoolean(self): def GetByteArray(self): return self.TRUE if self._value else self.FALSE + def GetByteLength(self): + return len(self.GetByteArray()) + def Serialize(self, writer): writer.WriteByte(StackItemType.Boolean) writer.WriteByte(self.GetBigInteger()) @@ -267,7 +281,6 @@ def __str__(self): class ByteArray(StackItem): - _value = None def __init__(self, value): self._value = value @@ -326,7 +339,6 @@ def __str__(self): class Integer(StackItem): - _value = None def __init__(self, value): if type(value) is not BigInteger: @@ -366,7 +378,6 @@ def __str__(self): class InteropInterface(StackItem): - _object = None def __init__(self, value): 
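All decoding flows from the two operand-size tables at the top of ``Instruction.py``: a fixed operand size per opcode, or a 1/2/4-byte length prefix for the ``PUSHDATA``/``SYSCALL`` family. A small hand-built sketch of what ``FromScriptAndIP`` is expected to yield (illustrative only; the values in the comments follow from the tables above):

.. code-block:: python

    from neo.Core.Cryptography.Crypto import Crypto
    from neo.VM.Script import Script
    from neo.VM.Instruction import Instruction
    from neo.VM.OpCode import PUSHDATA1, JMP

    # PUSHDATA1 with a 3-byte payload, then JMP with a little-endian int16 offset.
    raw = bytearray(PUSHDATA1) + bytes([3]) + b'abc' + bytearray(JMP) + (-4).to_bytes(2, 'little', signed=True)
    script = Script(Crypto.Default(), raw)

    ins = Instruction.FromScriptAndIP(script, 0)
    print(ins.InstructionName, bytes(ins.Operand), ins.Size)   # expected: PUSHDATA1 b'abc' 5

    nxt = Instruction.FromScriptAndIP(script, ins.Size)
    print(nxt.InstructionName, nxt.TokenI16, nxt.Size)         # expected: JMP -4 3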
self._object = value @@ -446,7 +457,6 @@ def __str__(self): class Map(StackItem, CollectionMixin): - _dict = None def __init__(self, dict=None): if dict: @@ -454,6 +464,10 @@ def __init__(self, dict=None): else: self._dict = {} + @property + def IsTypeMap(self): + return True + @property def Keys(self): return list(self._dict.keys()) @@ -535,9 +549,15 @@ def GetString(self): def GetByteArray(self): raise Exception("Not supported- Cant get byte array for item %s %s " % (type(self), self._dict)) + @classmethod + def FromDictionary(cls, dictionary: dict): + data = {} + for k, v in dictionary.items(): + data[StackItem.New(k)] = StackItem.New(v) + return cls(data) + class InteropService: - _dictionary = {} def __init__(self): self._dictionary = {} diff --git a/neo/VM/RandomAccessStack.py b/neo/VM/RandomAccessStack.py index 3ccda2b1a..47d2615c5 100644 --- a/neo/VM/RandomAccessStack.py +++ b/neo/VM/RandomAccessStack.py @@ -1,8 +1,4 @@ class RandomAccessStack: - _list = [] - _size = 0 # cache the size for performance - - _name = 'Stack' def __init__(self, name='Stack'): self._list = [] diff --git a/neo/VM/Script.py b/neo/VM/Script.py new file mode 100644 index 000000000..895ddea87 --- /dev/null +++ b/neo/VM/Script.py @@ -0,0 +1,28 @@ +class Script: + def __init__(self, crypto, script): + self._crypto = crypto + self._value = script + self._script_hash = None + + @property + def ScriptHash(self) -> bytearray: + if self._script_hash is None: + self._script_hash = self._crypto.Hash160(self._value) + return self._script_hash + + @property + def Length(self) -> int: + return len(self._value) + + def __call__(self, *args, **kwargs): + index = args[0] + return self._value[index] + + def __getitem__(self, item): + return self._value[item] + + @classmethod + def FromHash(cls, scrip_hash, script): + o = cls(None, script) + o._script_hash = scrip_hash + return o diff --git a/neo/VM/ScriptBuilder.py b/neo/VM/ScriptBuilder.py index 36dd5e247..9ce4411dc 100644 --- a/neo/VM/ScriptBuilder.py +++ b/neo/VM/ScriptBuilder.py @@ -8,7 +8,7 @@ import binascii from neo.VM.OpCode import PUSHDATA1, PUSHDATA2, PUSHDATA4, PUSHF, PUSHT, PACK, PUSH0, PUSH1, PUSHM1, PUSHBYTES75, \ APPCALL, TAILCALL, SYSCALL -from neo.IO.MemoryStream import MemoryStream +from neo.IO.MemoryStream import StreamManager from neo.Core.BigInteger import BigInteger @@ -17,7 +17,7 @@ class ScriptBuilder: def __init__(self): super(ScriptBuilder, self).__init__() - self.ms = MemoryStream() # MemoryStream + self.ms = StreamManager.GetStream() # MemoryStream def WriteUInt16(self, value, endian="<"): return self.pack('%sH' % endian, value) @@ -230,7 +230,7 @@ def EmitSysCallWithArguments(self, api, args): def ToArray(self, cleanup=True): retval = self.ms.ToArray() if cleanup: - self.ms.Cleanup() + StreamManager.ReleaseStream(self.ms) self.ms = None return retval diff --git a/neo/VM/tests/JsonTester.py b/neo/VM/tests/JsonTester.py index e3ad3184c..e59e6d2c8 100644 --- a/neo/VM/tests/JsonTester.py +++ b/neo/VM/tests/JsonTester.py @@ -7,10 +7,13 @@ from neo.VM.ExecutionEngine import ExecutionContext from neo.VM.RandomAccessStack import RandomAccessStack from neo.Core.Cryptography.Crypto import Crypto +from neo.Core.UInt160 import UInt160 from typing import Optional from neo.VM.VMState import VMStateStr from neo.VM.OpCode import ToName as OpcodeToName +from neo.VM.OpCode import RET from neo.VM import InteropService +from neo.VM.Debugger import Debugger class MessageProvider: @@ -35,7 +38,9 @@ def GetScript(self, script_hash: bytes) -> Optional[bytes]: return 
self.data.get(script_hash, None) def Add(self, script: bytearray) -> None: - self.data[Crypto.Hash160(script)] = script + h = bytearray(Crypto.Default().Hash160(script)) + h.reverse() + self.data[binascii.hexlify(h)] = script file_count = 0 @@ -73,11 +78,25 @@ def execute_test(data: dict): script_container = MessageProvider(message) # prepare script table - script_table = None # there are currently no tests that load a script table so I don't know the format or key value they'll use + scripts = test.get("scriptTable", None) + script_table = None + if scripts: + script_table = ScriptTable() + for entry in scripts: + try: + script = binascii.unhexlify(entry['script'][2:]) + script_table.Add(script) + except binascii.Error: + print(f"Skipping test {data['category']}-{data['name']}, cannot read script data") + test_count -= 1 + skipped_test_count += 1 + continue # create engine and run engine = ExecutionEngine(crypto=Crypto.Default(), service=service, container=script_container, table=script_table, exit_on_error=True) + debugger = Debugger(engine) + # TODO: should enforce 0x rule in the JSON test case if test['script'].startswith('0x'): script = test['script'][2:] @@ -101,13 +120,13 @@ def execute_test(data: dict): actions = step.get('actions', []) for action in actions: if action == "StepInto": - engine.StepInto() + debugger.StepInto() elif action == "Execute": - engine.Execute() + debugger.Execute() elif action == "StepOver": - raise ValueError("StepOver not supported!") + debugger.StepOver() elif action == "StepOut": - raise ValueError("StepOut not supported!") + debugger.StepOut() test_name = test.get("name", "") msg = f"{data['category']}-{data['name']}-{test_name}-{i}" @@ -136,8 +155,14 @@ def assert_invocation_stack(istack: RandomAccessStack, result: dict, msg: str): actual_script_hash = binascii.hexlify(actual_context.ScriptHash()).decode() assert actual_script_hash == expected_script_hash, f"[{msg}] Script hash differs! Expected: {expected_script_hash} Actual: {actual_script_hash}" + opcode = RET if actual_context.InstructionPointer >= actual_context.Script.Length else actual_context.Script[actual_context.InstructionPointer] expected_next_instruction = expected_context['nextInstruction'] - actual_next_instruction = OpcodeToName(actual_context.NextInstruction) + # hack to work around C#'s lack of having defined enum members for PUSHBYTES2-PUSHBYTES74 + # TODO: remove this once neo-vm is updated to have human readable names for the above enum members + if expected_next_instruction.isdecimal(): + expected_next_instruction = OpcodeToName(int(expected_next_instruction)) + + actual_next_instruction = OpcodeToName(opcode) assert actual_next_instruction == expected_next_instruction, f"[{msg}] Next instruction differs! 
Expected: {expected_next_instruction} Actual: {actual_next_instruction}" expected_ip = expected_context['instructionPointer'] diff --git a/neo/VM/tests/test_execution_engine.py b/neo/VM/tests/test_execution_engine.py index 0f0fd948c..d0ae78fee 100644 --- a/neo/VM/tests/test_execution_engine.py +++ b/neo/VM/tests/test_execution_engine.py @@ -3,6 +3,7 @@ from neo.VM.ExecutionEngine import ExecutionEngine from neo.VM.ExecutionEngine import ExecutionContext from neo.VM import OpCode +from neo.VM.Script import Script from neo.Core.Cryptography.Crypto import Crypto from mock import patch import binascii @@ -15,13 +16,15 @@ class VMTestCase(TestCase): def setUp(self): self.engine = ExecutionEngine(crypto=Crypto.Default()) - self.econtext = ExecutionContext(engine=self.engine) + self.econtext = ExecutionContext(Script(self.engine.Crypto, b''), 0) + self.engine.InvocationStack.PushT(self.econtext) def test_add_operations(self): self.econtext.EvaluationStack.PushT(StackItem.New(2)) self.econtext.EvaluationStack.PushT(StackItem.New(3)) + self.econtext.Script._value = OpCode.ADD - self.engine.ExecuteOp(OpCode.ADD, self.econtext) + self.engine.ExecuteInstruction() self.assertEqual(len(self.econtext.EvaluationStack.Items), 1) @@ -30,8 +33,9 @@ def test_add_operations(self): def test_sub_operations(self): self.econtext.EvaluationStack.PushT(StackItem.New(2)) self.econtext.EvaluationStack.PushT(StackItem.New(3)) + self.econtext.Script._value = OpCode.SUB - self.engine.ExecuteOp(OpCode.SUB, self.econtext) + self.engine.ExecuteInstruction() self.assertEqual(len(self.econtext.EvaluationStack.Items), 1) @@ -42,14 +46,16 @@ def test_verify_sig(self): self.econtext.EvaluationStack.PushT(stackItemMessage) # sig - sig = binascii.unhexlify(b'cd0ca967d11cea78e25ad16f15dbe77672258bfec59ff3617c95e317acff063a48d35f71aa5ce7d735977412186e1572507d0f4d204c5bcb6c90e03b8b857fbd') + sig = binascii.unhexlify( + b'cd0ca967d11cea78e25ad16f15dbe77672258bfec59ff3617c95e317acff063a48d35f71aa5ce7d735977412186e1572507d0f4d204c5bcb6c90e03b8b857fbd') self.econtext.EvaluationStack.PushT(StackItem.New(sig)) # pubkey pubkey = binascii.unhexlify(b'036fbcb5e138c1ce5360e861674c03228af735a9114a5b7fb4121b8350129f3ffe') self.econtext.EvaluationStack.PushT(pubkey) - self.engine.ExecuteOp(OpCode.VERIFY, self.econtext) + self.econtext.Script._value = OpCode.VERIFY + self.engine.ExecuteInstruction() res = self.econtext.EvaluationStack.Pop() self.assertEqual(res, StackItem.New(True)) @@ -60,14 +66,16 @@ def test_verify_sig_fail(self): self.econtext.EvaluationStack.PushT(stackItemMessage) # sig - sig = binascii.unhexlify(b'cd0ca967d11cea78e25ad16f15dbe77672258bfec59ff3617c95e317acff063a48d35f71aa5ce7d735977412186e1572507d0f4d204c5bcb6c90e03b8b857fbd') + sig = binascii.unhexlify( + b'cd0ca967d11cea78e25ad16f15dbe77672258bfec59ff3617c95e317acff063a48d35f71aa5ce7d735977412186e1572507d0f4d204c5bcb6c90e03b8b857fbd') self.econtext.EvaluationStack.PushT(StackItem.New(sig)) # pubkey pubkey = binascii.unhexlify(b'036fbcb5e138c1ce5360e861674c03228af735a9114a5b7fb4121b8350129f3ffd') self.econtext.EvaluationStack.PushT(pubkey) - self.engine.ExecuteOp(OpCode.VERIFY, self.econtext) + self.econtext.Script._value = OpCode.VERIFY + self.engine.ExecuteInstruction() res = self.econtext.EvaluationStack.Pop() self.assertEqual(res, StackItem.New(False)) diff --git a/neo/VM/tests/test_interop_blockchain.py b/neo/VM/tests/test_interop_blockchain.py index bfcadefb2..7aad5d9c7 100644 --- a/neo/VM/tests/test_interop_blockchain.py +++ b/neo/VM/tests/test_interop_blockchain.py 
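With ``ExecuteOp(opcode, context)`` gone, opcode-level unit tests now stage the opcode bytes in the context's ``Script`` and call ``ExecuteInstruction()``, letting the engine fetch and decode the instruction itself. The updated tests above and below all follow that shape; a condensed sketch, poking the private ``Script._value`` exactly as those tests do:

.. code-block:: python

    from neo.Core.Cryptography.Crypto import Crypto
    from neo.VM.ExecutionEngine import ExecutionEngine
    from neo.VM.ExecutionContext import ExecutionContext
    from neo.VM.InteropService import StackItem
    from neo.VM.Script import Script
    from neo.VM import OpCode

    engine = ExecutionEngine(crypto=Crypto.Default())
    econtext = ExecutionContext(Script(engine.Crypto, b''), 0)
    engine.InvocationStack.PushT(econtext)

    # Arrange the operands, point the test-only script at the opcode under test,
    # then let the engine execute whatever sits at the current instruction pointer.
    econtext.EvaluationStack.PushT(StackItem.New(2))
    econtext.EvaluationStack.PushT(StackItem.New(3))
    econtext.Script._value = OpCode.ADD
    engine.ExecuteInstruction()

    # ADD consumes both operands and leaves a single item (5) on the stack.
    assert len(econtext.EvaluationStack.Items) == 1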
@@ -7,6 +7,7 @@ from neo.Core.TX.Transaction import Transaction from neo.Settings import settings from neo.Core.UInt256 import UInt256 +from neo.VM.Script import Script import os @@ -30,7 +31,8 @@ def setUpClass(cls): def setUp(self): self.engine = ExecutionEngine() - self.econtext = ExecutionContext(engine=self.engine) + self.econtext = ExecutionContext(Script(self.engine.Crypto, b''), 0) + self.engine.InvocationStack.PushT(self.econtext) self.state_reader = StateReader() def test_interop_getblock(self): diff --git a/neo/VM/tests/test_interop_map.py b/neo/VM/tests/test_interop_map.py index 4e70f1cbf..0beb4eda9 100644 --- a/neo/VM/tests/test_interop_map.py +++ b/neo/VM/tests/test_interop_map.py @@ -4,6 +4,7 @@ from neo.VM.ExecutionContext import ExecutionContext from neo.VM import OpCode from neo.VM import VMState +from neo.VM.Script import Script import logging @@ -22,7 +23,8 @@ def setUpClass(cls): def setUp(self): self.engine = ExecutionEngine() - self.econtext = ExecutionContext(engine=self.engine) + self.econtext = ExecutionContext(Script(self.engine.Crypto, b''), 0) + self.engine.InvocationStack.PushT(self.econtext) def test_interop_map1(self): map = Map() @@ -72,17 +74,20 @@ def test_interop_map3(self): self.assertEqual(map.GetMap(), {'b': 2, 'c': 3, 'h': 9}) def test_op_map1(self): - self.engine.ExecuteOp(OpCode.NEWMAP, self.econtext) + self.econtext.Script._value = OpCode.NEWMAP + self.engine.ExecuteInstruction() self.assertEqual(len(self.econtext.EvaluationStack.Items), 1) self.assertIsInstance(self.econtext.EvaluationStack.Items[0], Map) self.assertEqual(self.econtext.EvaluationStack.Items[0].GetMap(), {}) def test_op_map2(self): - self.engine.ExecuteOp(OpCode.NEWMAP, self.econtext) + self.econtext.Script._value = OpCode.NEWMAP + OpCode.SETITEM + self.engine.ExecuteInstruction() + self.econtext.EvaluationStack.PushT(StackItem.New('mykey')) self.econtext.EvaluationStack.PushT(StackItem.New('myVal')) - self.engine.ExecuteOp(OpCode.SETITEM, self.econtext) + self.engine.ExecuteInstruction() self.assertEqual(len(self.econtext.EvaluationStack.Items), 0) @@ -91,9 +96,10 @@ def test_op_map3(self): self.econtext.EvaluationStack.PushT(StackItem.New('myvalue')) self.econtext.EvaluationStack.PushT(StackItem.New('mykey')) + self.econtext.Script._value = OpCode.SETITEM with self.assertRaises(Exception) as context: - self.engine.ExecuteOp(OpCode.SETITEM, self.econtext) + self.engine.ExecuteInstruction() self.assertEqual(len(self.econtext.EvaluationStack.Items), 0) self.assertEqual(self.engine.State, VMState.BREAK) @@ -101,10 +107,12 @@ def test_op_map3(self): def test_op_map4(self): with self.assertLogHandler('vm', logging.DEBUG) as log_context: # set item should fail if these are out of order + self.econtext.Script._value = OpCode.NEWMAP + OpCode.SETITEM + self.econtext.EvaluationStack.PushT(StackItem.New('mykey')) - self.engine.ExecuteOp(OpCode.NEWMAP, self.econtext) + self.engine.ExecuteInstruction() self.econtext.EvaluationStack.PushT(StackItem.New('myVal')) - self.engine.ExecuteOp(OpCode.SETITEM, self.econtext) + self.engine.ExecuteInstruction() self.assertEqual(self.engine.State, VMState.FAULT) self.assertTrue(len(log_context.output) > 0) @@ -117,7 +125,8 @@ def test_op_map5(self): self.econtext.EvaluationStack.PushT(StackItem.New('mykey')) self.econtext.EvaluationStack.PushT(StackItem.New('mykey')) self.econtext.EvaluationStack.PushT(StackItem.New('myVal')) - self.engine.ExecuteOp(OpCode.SETITEM, self.econtext) + self.econtext.Script._value = OpCode.SETITEM + 
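The hand-built ``Map(dict={StackItem.New(...): ...})`` literals in these tests can also be produced with the ``Map.FromDictionary`` helper added to ``InteropService`` earlier in this patch, which wraps plain keys and values in ``StackItem`` instances. A small equivalence sketch:

.. code-block:: python

    from neo.VM.InteropService import Map, StackItem

    # Wrap every key and value by hand...
    by_hand = Map(dict={StackItem.New('a'): StackItem.New(4), StackItem.New('b'): StackItem.New(5)})

    # ...or let FromDictionary do the StackItem.New() wrapping.
    from_helper = Map.FromDictionary({'a': 4, 'b': 5})

    assert len(by_hand.GetMap()) == len(from_helper.GetMap()) == 2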
self.engine.ExecuteInstruction() self.assertEqual(self.engine.State, VMState.FAULT) @@ -129,7 +138,8 @@ def test_op_map6(self): # we can pick an item from a dict self.econtext.EvaluationStack.PushT(Map(dict={StackItem.New('a'): StackItem.New(4)})) self.econtext.EvaluationStack.PushT(StackItem.New('a')) - self.engine.ExecuteOp(OpCode.PICKITEM, self.econtext) + self.econtext.Script._value = OpCode.PICKITEM + self.engine.ExecuteInstruction() self.assertEqual(len(self.econtext.EvaluationStack.Items), 1) self.assertEqual(self.econtext.EvaluationStack.Items[0].GetBigInteger(), 4) @@ -139,7 +149,8 @@ def test_op_map7(self): # pick item with key is collection causes error self.econtext.EvaluationStack.PushT(Map(dict={StackItem.New('a'): StackItem.New(4)})) self.econtext.EvaluationStack.PushT(Map(dict={StackItem.New('a'): StackItem.New(4)})) - self.engine.ExecuteOp(OpCode.PICKITEM, self.econtext) + self.econtext.Script._value = OpCode.PICKITEM + self.engine.ExecuteInstruction() self.assertEqual(self.engine.State, VMState.FAULT) self.assertTrue(len(log_context.output) > 0) @@ -150,7 +161,8 @@ def test_op_map8(self): # pick item out of bounds self.econtext.EvaluationStack.PushT(StackItem.New('a')) self.econtext.EvaluationStack.PushT(StackItem.New('a')) - self.engine.ExecuteOp(OpCode.PICKITEM, self.econtext) + self.econtext.Script._value = OpCode.PICKITEM + self.engine.ExecuteInstruction() self.assertTrue(len(log_context.output) > 0) log_msg = log_context.output[0] @@ -164,7 +176,8 @@ def test_op_map9(self): # pick item key not found self.econtext.EvaluationStack.PushT(Map(dict={StackItem.New('a'): StackItem.New(4)})) self.econtext.EvaluationStack.PushT(StackItem.New('b')) - self.engine.ExecuteOp(OpCode.PICKITEM, self.econtext) + self.econtext.Script._value = OpCode.PICKITEM + self.engine.ExecuteInstruction() self.assertEqual(self.engine.State, VMState.FAULT) self.assertTrue(len(log_context.output) > 0) @@ -173,7 +186,8 @@ def test_op_map9(self): def test_op_map10(self): # pick item key not found self.econtext.EvaluationStack.PushT(Map(dict={StackItem.New('a'): StackItem.New(4), StackItem.New('b'): StackItem.New(5)})) - self.engine.ExecuteOp(OpCode.KEYS, self.econtext) + self.econtext.Script._value = OpCode.KEYS + self.engine.ExecuteInstruction() self.assertIsInstance(self.econtext.EvaluationStack.Items[0], Array) items = self.econtext.EvaluationStack.Items[0].GetArray() @@ -181,7 +195,8 @@ def test_op_map10(self): def test_op_map11(self): self.econtext.EvaluationStack.PushT(Map(dict={StackItem.New('a'): StackItem.New(4), StackItem.New('b'): StackItem.New(5)})) - self.engine.ExecuteOp(OpCode.VALUES, self.econtext) + self.econtext.Script._value = OpCode.VALUES + self.engine.ExecuteInstruction() self.assertIsInstance(self.econtext.EvaluationStack.Items[0], Array) items = self.econtext.EvaluationStack.Items[0].GetArray() @@ -190,13 +205,15 @@ def test_op_map11(self): def test_op_map12(self): self.econtext.EvaluationStack.PushT(Map(dict={StackItem.New('a'): StackItem.New(4), StackItem.New('b'): StackItem.New(5)})) self.econtext.EvaluationStack.PushT(StackItem.New('b')) - self.engine.ExecuteOp(OpCode.HASKEY, self.econtext) + self.econtext.Script._value = OpCode.HASKEY + self.engine.ExecuteInstruction() self.assertEqual(self.econtext.EvaluationStack.Items[0].GetBoolean(), True) def test_op_map13(self): self.econtext.EvaluationStack.PushT(Map(dict={StackItem.New('a'): StackItem.New(4), StackItem.New('b'): StackItem.New(5)})) self.econtext.EvaluationStack.PushT(StackItem.New('c')) - 
self.engine.ExecuteOp(OpCode.HASKEY, self.econtext) + self.econtext.Script._value = OpCode.HASKEY + self.engine.ExecuteInstruction() self.assertEqual(self.econtext.EvaluationStack.Items[0].GetBoolean(), False) diff --git a/neo/VM/tests/test_interop_serialize.py b/neo/VM/tests/test_interop_serialize.py index d5b12d7d4..d3cf9a689 100644 --- a/neo/VM/tests/test_interop_serialize.py +++ b/neo/VM/tests/test_interop_serialize.py @@ -6,6 +6,8 @@ from neo.Core.IO.BinaryReader import BinaryReader from neo.IO.MemoryStream import StreamManager from neo.Core.Blockchain import Blockchain +from neo.VM.Script import Script +from neo.Core.Cryptography.Crypto import Crypto import logging @@ -15,8 +17,9 @@ class InteropSerializeDeserializeTestCase(NeoTestCase): state_reader = None def setUp(self): - self.engine = ExecutionEngine() - self.econtext = ExecutionContext(engine=self.engine) + self.engine = ExecutionEngine(crypto=Crypto.Default()) + self.econtext = ExecutionContext(Script(self.engine.Crypto, b''), 0) + self.engine.InvocationStack.PushT(self.econtext) self.state_reader = StateReader() def test_serialize_struct(self): diff --git a/neo/Wallets/Coin.py b/neo/Wallets/Coin.py index adbb7590d..3aef6cc0b 100644 --- a/neo/Wallets/Coin.py +++ b/neo/Wallets/Coin.py @@ -11,13 +11,6 @@ class Coin(TrackableMixin): - Output = None - Reference = None - - _address = None - _state = CoinState.Unconfirmed - _transaction = None - @staticmethod def CoinFromRef(coin_ref, tx_output, state=CoinState.Unconfirmed, transaction=None): """ @@ -47,6 +40,10 @@ def __init__(self, prev_hash=None, prev_index=None, tx_output=None, coin_referen coin_reference (neo.Core.CoinReference): (Optional if prev_hash and prev_index are given) an object representing a single UTXO / transaction input. state (neo.Core.State.CoinState): """ + + self._address = None + self._transaction = None + if prev_hash and prev_index: self.Reference = CoinReference(prev_hash, prev_index) elif coin_reference: diff --git a/neo/Wallets/Wallet.py b/neo/Wallets/Wallet.py index e2b1fd98d..09c19b9c4 100755 --- a/neo/Wallets/Wallet.py +++ b/neo/Wallets/Wallet.py @@ -6,6 +6,7 @@ """ import traceback import hashlib +import asyncio from itertools import groupby from base58 import b58decode from decimal import Decimal @@ -35,22 +36,6 @@ class Wallet: - AddressVersion = None - - _path = '' - _iv = None - _master_key = None - _keys = {} # holds keypairs - _contracts = {} # holds Contracts - _tokens = {} # holds references to NEP5 tokens - _watch_only = [] # holds set of hashes - _coins = {} # holds Coin References - - _current_height = 0 - - _vin_exclude = None - - _lock = None # allows locking for threads that may need to access the DB concurrently (e.g. ProcessBlocks and Rebuild) @property def WalletHeight(self): @@ -66,7 +51,10 @@ def __init__(self, path, passwordKey, create): passwordKey (aes_key): A password that has been converted to aes key with neo.Wallets.utils.to_aes_key create (bool): Whether to create the wallet or simply open. """ - + self._tokens = {} # holds references to NEP5 tokens + self._watch_only = [] # holds set of hashes + self._current_height = 0 + self._vin_exclude = None self.AddressVersion = settings.ADDRESS_VERSION self._path = path self._lock = RLock() @@ -1016,7 +1004,7 @@ def MakeTransaction(self, skip_fee_calc (bool): If true, the network fee calculation and verification will be skipped. Returns: - tx: (Transaction) Returns the transaction with oupdated inputs and outputs. 
+ tx: (Transaction) Returns the transaction with updated inputs and outputs. """ tx.ResetReferences() @@ -1111,7 +1099,8 @@ def MakeTransaction(self, if req_fee < settings.LOW_PRIORITY_THRESHOLD: req_fee = settings.LOW_PRIORITY_THRESHOLD if fee < req_fee: - raise TXFeeError(f'Transaction cancelled. The tx size ({tx.Size()}) exceeds the max free tx size ({settings.MAX_FREE_TX_SIZE}).\nA network fee of {req_fee.ToString()} GAS is required.') + raise TXFeeError( + f'Transaction cancelled. The tx size ({tx.Size()}) exceeds the max free tx size ({settings.MAX_FREE_TX_SIZE}).\nA network fee of {req_fee.ToString()} GAS is required.') return tx @@ -1270,6 +1259,22 @@ def IsSynced(self): def pretty_print(self, verbose=False): pass + async def sync_wallet(self, start_block, rebuild=False): + Blockchain.Default().PersistCompleted.on_change -= self.ProcessNewBlock + + if rebuild: + self.Rebuild(start_block) + while True: + # trying with 100, might need to lower if processing takes too long + self.ProcessBlocks(block_limit=100) + + if self.IsSynced: + break + # give some time to other tasks + await asyncio.sleep(0.05) + + Blockchain.Default().PersistCompleted.on_change += self.ProcessNewBlock + def ToJson(self, verbose=False): # abstract pass diff --git a/neo/api/JSONRPC/JsonRpcApi.py b/neo/api/JSONRPC/JsonRpcApi.py index e8206cf4b..fb7394ddd 100644 --- a/neo/api/JSONRPC/JsonRpcApi.py +++ b/neo/api/JSONRPC/JsonRpcApi.py @@ -1,43 +1,42 @@ """ -The JSON-RPC API is using the Python package 'klein', which makes it possible to -create HTTP routes and handlers with Twisted in a similar style to Flask: -https://github.com/twisted/klein +The JSON-RPC API is using the Python package 'aioHttp' See also: * http://www.jsonrpc.org/specification """ -import json -import base58 +import os +import ast import binascii +import logging from json.decoder import JSONDecodeError -from klein import Klein +import aiohttp_cors +import base58 +from aiohttp import web +from aiohttp.helpers import MultiDict -from neo.Settings import settings from neo.Core.Blockchain import Blockchain -from neo.api.utils import json_response, cors_header +from neo.Core.Fixed8 import Fixed8 +from neo.Core.Helper import Helper from neo.Core.State.AccountState import AccountState +from neo.Core.State.CoinState import CoinState +from neo.Core.State.StorageKey import StorageKey from neo.Core.TX.Transaction import Transaction, TransactionOutput, \ ContractTransaction, TXFeeError from neo.Core.TX.TransactionAttribute import TransactionAttribute, \ TransactionAttributeUsage -from neo.Core.State.CoinState import CoinState from neo.Core.UInt160 import UInt160 from neo.Core.UInt256 import UInt256 -from neo.Core.Fixed8 import Fixed8 -from neo.Core.Helper import Helper -from neo.Network.NodeLeader import NodeLeader -from neo.Core.State.StorageKey import StorageKey +from neo.Implementations.Wallets.peewee.Models import Account +from neo.Network.nodemanager import NodeManager +from neo.Prompt.Utils import get_asset_id +from neo.Settings import settings from neo.SmartContract.ApplicationEngine import ApplicationEngine from neo.SmartContract.ContractParameter import ContractParameter from neo.SmartContract.ContractParameterContext import ContractParametersContext from neo.VM.ScriptBuilder import ScriptBuilder from neo.VM.VMState import VMStateStr -from neo.Implementations.Wallets.peewee.Models import Account -from neo.Prompt.Utils import get_asset_id -from neo.Wallets.Wallet import Wallet -from furl import furl -import ast +from neo.api.utils import json_response 
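The new ``Wallet.sync_wallet()`` coroutine above processes blocks in batches of 100 and yields back to the event loop between batches. A hypothetical caller could schedule it as an ordinary asyncio task; the wallet object and entry point below are placeholders, only the coroutine signature comes from the diff::

    import asyncio

    async def run_node(wallet):
        # sync_wallet() awaits asyncio.sleep(0.05) between ProcessBlocks()
        # batches, so other tasks (networking, RPC) keep running while the
        # wallet catches up.
        sync_task = asyncio.create_task(wallet.sync_wallet(start_block=0, rebuild=False))
        await sync_task

    # asyncio.run(run_node(my_wallet))  # `my_wallet`: an opened wallet instance (placeholder)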
class JsonRpcError(Exception): @@ -79,12 +78,36 @@ def internalError(message=None): class JsonRpcApi: - app = Klein() - port = None - def __init__(self, port, wallet=None): - self.port = port + def __init__(self, wallet=None): + if not os.getenv("NEOPYTHON_UNITTEST"): + stdio_handler = logging.StreamHandler() + stdio_handler.setLevel(logging.INFO) + _logger = logging.getLogger('aiohttp.access') + _logger.addHandler(stdio_handler) + _logger.setLevel(logging.DEBUG) + + self.app = web.Application(logger=_logger) + else: + self.app = web.Application() + self.port = settings.RPC_PORT self.wallet = wallet + self.nodemgr = NodeManager() + + cors = aiohttp_cors.setup(self.app, defaults={ + "*": aiohttp_cors.ResourceOptions( + allow_headers=('Content-Type', 'Access-Control-Allow-Headers', 'Authorization', 'X-Requested-With') + ) + }) + + self.app.router.add_post("/", self.home) + # TODO: find a fix for adding an OPTIONS route in combination with CORS. It works fine without CORS + # self.app.router.add_options("/", self.home) + self.app.router.add_get("/", self.home) + + for route in list(self.app.router.routes()): + if not isinstance(route.resource, web.StaticResource): # <<< WORKAROUND + cors.add(route) def get_data(self, body: dict): @@ -117,11 +140,12 @@ def get_data(self, body: dict): # # JSON-RPC API Route - # - @app.route('/') + # TODO: re-enable corse_header support + # fix tests + # someday rewrite to allow async methods, have another look at https://github.com/bcb/jsonrpcserver/tree/master/jsonrpcserver + # the only downside of that plugin is that it does not support custom errors. Either patch or request @json_response - @cors_header - def home(self, request): + async def home(self, request): # POST Examples: # {"jsonrpc": "2.0", "id": 5, "method": "getblockcount", "params": []} # or multiple requests in 1 transaction @@ -132,9 +156,10 @@ def home(self, request): # NOTE: GET requests do not support multiple requests in 1 transaction request_id = None - if "POST" == request.method.decode("utf-8"): + if "POST" == request.method: try: - content = json.loads(request.content.read().decode("utf-8")) + content = await request.json() + # content = json.loads(content.decode('utf-8')) # test if it's a multi-request message if isinstance(content, list): @@ -150,36 +175,19 @@ def home(self, request): error = JsonRpcError.parseError() return self.get_custom_error_payload(request_id, error.code, error.message) - elif "GET" == request.method.decode("utf-8"): - content = furl(request.uri).args - - # remove hanging ' or " from last value if value is not None to avoid SyntaxError - try: - l_value = list(content.values())[-1] - except IndexError: - error = JsonRpcError.parseError() - return self.get_custom_error_payload(request_id, error.code, error.message) - - if l_value is not None: - n_value = l_value[:-1] - l_key = list(content.keys())[-1] - content[l_key] = n_value - - if len(content.keys()) > 3: - try: - params = content['params'] - l_params = ast.literal_eval(params) - content['params'] = [l_params] - except KeyError: - error = JsonRpcError(-32602, "Invalid params") - return self.get_custom_error_payload(request_id, error.code, error.message) + elif "GET" == request.method: + content = MultiDict(request.query) + params = content.get("params", None) + if params and not isinstance(params, list): + new_params = ast.literal_eval(params) + content.update({'params': new_params}) return self.get_data(content) - elif "OPTIONS" == request.method.decode("utf-8"): + elif "OPTIONS" == request.method: return 
self.options_response() - error = JsonRpcError.invalidRequest("%s is not a supported HTTP method" % request.method.decode("utf-8")) + error = JsonRpcError.invalidRequest("%s is not a supported HTTP method" % request.method) return self.get_custom_error_payload(request_id, error.code, error.message) @classmethod @@ -241,7 +249,7 @@ def json_rpc_method_handler(self, method, params): raise JsonRpcError(-100, "Invalid Height") elif method == "getconnectioncount": - return len(NodeLeader.Instance().Peers) + return len(self.nodemgr.nodes) elif method == "getcontractstate": script_hash = UInt160.ParseString(params[0]) @@ -251,12 +259,12 @@ def json_rpc_method_handler(self, method, params): return contract.ToJson() elif method == "getrawmempool": - return list(map(lambda hash: "0x%s" % hash.decode('utf-8'), NodeLeader.Instance().MemPool.keys())) + return list(map(lambda hash: f"{hash.To0xString()}", self.nodemgr.mempool.pool.keys())) elif method == "getversion": return { "port": self.port, - "nonce": NodeLeader.Instance().NodeId, + "nonce": self.nodemgr.id, "useragent": settings.VERSION_NAME } @@ -320,7 +328,8 @@ def json_rpc_method_handler(self, method, params): elif method == "sendrawtransaction": tx_script = binascii.unhexlify(params[0].encode('utf-8')) transaction = Transaction.DeserializeFromBufer(tx_script) - result = NodeLeader.Instance().Relay(transaction) + # TODO: relay blocks, change to await in the future + result = self.nodemgr.relay(transaction) return result elif method == "validateaddress": @@ -451,26 +460,20 @@ def validateaddress(self, params): def get_peers(self): """Get all known nodes and their 'state' """ - node = NodeLeader.Instance() + result = {"connected": [], "unconnected": [], "bad": []} - connected_peers = [] - for peer in node.Peers: - result['connected'].append({"address": peer.host, - "port": peer.port}) - connected_peers.append("{}:{}".format(peer.host, peer.port)) + for peer in self.nodemgr.nodes: + host, port = peer.address.rsplit(":") + result['connected'].append({"address": host, "port": int(port)}) - for addr in node.DEAD_ADDRS: + for addr in self.nodemgr.bad_addresses: host, port = addr.rsplit(':', 1) - result['bad'].append({"address": host, "port": port}) + result['bad'].append({"address": host, "port": int(port)}) - # "UnconnectedPeers" is never used. 
So a check is needed to - # verify that a given address:port does not belong to a connected peer - for addr in node.KNOWN_ADDRS: + for addr in self.nodemgr.known_addresses: host, port = addr.rsplit(':', 1) - if addr not in connected_peers: - result['unconnected'].append({"address": host, - "port": int(port)}) + result['unconnected'].append({"address": host, "port": int(port)}) return result @@ -636,7 +639,7 @@ def process_transaction(self, contract_tx, fee=None, address_from=None, change_a if context.Completed: tx.scripts = context.GetScripts() self.wallet.SaveTransaction(tx) - NodeLeader.Instance().Relay(tx) + self.nodemgr.relay(tx) return tx.ToJson() else: return context.ToJson() diff --git a/neo/api/JSONRPC/test_json_invoke_rpc_api.py b/neo/api/JSONRPC/test_json_invoke_rpc_api.py index d0630c791..6df50fe92 100644 --- a/neo/api/JSONRPC/test_json_invoke_rpc_api.py +++ b/neo/api/JSONRPC/test_json_invoke_rpc_api.py @@ -4,72 +4,58 @@ $ python -m unittest neo.api.JSONRPC.test_json_rpc_api """ import json -import pprint -import binascii import os -from klein.test.test_resource import requestMock -from twisted.web import server -from twisted.web.test.test_web import DummyChannel -from neo import __version__ -from neo.api.JSONRPC.JsonRpcApi import JsonRpcApi -from neo.Utils.BlockchainFixtureTestCase import BlockchainFixtureTestCase -from neo.IO.Helper import Helper +from aiohttp.test_utils import AioHTTPTestCase + +from neo.Settings import settings from neo.SmartContract.ContractParameter import ContractParameter from neo.SmartContract.ContractParameterType import ContractParameterType -from neo.Core.UInt160 import UInt160 +from neo.Utils.BlockchainFixtureTestCase import BlockchainFixtureTestCase from neo.VM import VMState from neo.VM.VMState import VMStateStr -from neo.Settings import settings +from neo.api.JSONRPC.JsonRpcApi import JsonRpcApi -def mock_post_request(body): - return requestMock(path=b'/', method=b"POST", body=body) +class JsonRpcInvokeApiTestCase(BlockchainFixtureTestCase, AioHTTPTestCase): + def __init__(self, *args, **kwargs): + super(JsonRpcInvokeApiTestCase, self).__init__(*args, **kwargs) + self.api_server = JsonRpcApi() -def mock_get_request(path, method=b"GET"): - request = server.Request(DummyChannel(), False) - request.uri = path - request.method = method - request.clientproto = b'HTTP/1.1' - return request + async def get_application(self): + """ + Override the get_app method to return your application. 
+ """ + return self.api_server.app -class JsonRpcInvokeApiTestCase(BlockchainFixtureTestCase): - app = None # type:JsonRpcApi + def do_test_get(self, url, data=None): + async def test_get_route(url, data=None): + resp = await self.client.get(url, data=data) + text = await resp.text() + return text - @classmethod - def leveldb_testpath(cls): - return os.path.join(settings.DATA_DIR_PATH, 'fixtures/test_chain') - - def setUp(self): - self.app = JsonRpcApi(9479) + return self.loop.run_until_complete(test_get_route(url, data)) - def test_invalid_request_method(self): - # test HEAD method - mock_req = mock_get_request(b'/?test', b"HEAD") - res = json.loads(self.app.home(mock_req)) - self.assertEqual(res["error"]["code"], -32600) - self.assertEqual(res["error"]["message"], 'HEAD is not a supported HTTP method') + def do_test_post(self, url, data=None, json=None): + if data is not None and json is not None: + raise ValueError("cannot specify `data` and `json` at the same time") - def test_invalid_json_payload(self): - # test POST requests - mock_req = mock_post_request(b"{ invalid") - res = json.loads(self.app.home(mock_req)) - self.assertEqual(res["error"]["code"], -32700) + async def test_get_route(url, data=None, json=None): + if data: + resp = await self.client.post(url, data=data) + else: + resp = await self.client.post(url, json=json) - mock_req = mock_post_request(json.dumps({"some": "stuff"}).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - self.assertEqual(res["error"]["code"], -32600) + text = await resp.text() + return text - # test GET requests - mock_req = mock_get_request(b"/?%20invalid") # equivalent to "/? invalid" - res = json.loads(self.app.home(mock_req)) - self.assertEqual(res["error"]["code"], -32600) + return self.loop.run_until_complete(test_get_route(url, data, json)) - mock_req = mock_get_request(b"/?some=stuff") - res = json.loads(self.app.home(mock_req)) - self.assertEqual(res["error"]["code"], -32600) + @classmethod + def leveldb_testpath(cls): + return os.path.join(settings.DATA_DIR_PATH, 'fixtures/test_chain') def _gen_post_rpc_req(self, method, params=None, request_id="2"): ret = { @@ -85,7 +71,7 @@ def _gen_get_rpc_req(self, method, params=None, request="2"): ret = "/?jsonrpc=2.0&method=%s&params=[]&id=%s" % (method, request) if params: ret = "/?jsonrpc=2.0&method=%s&params=%s&id=%s" % (method, params, request) - return ret.encode('utf-8') + return ret def test_invoke_1(self): # test POST requests @@ -101,8 +87,7 @@ def test_invoke_1(self): } ] req = self._gen_post_rpc_req("invoke", params=[contract_hash, jsn]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['state'], VMStateStr(VMState.HALT)) self.assertEqual(res['result']['gas_consumed'], '0.128') results = [] @@ -113,9 +98,8 @@ def test_invoke_1(self): self.assertEqual(results[0].Value, bytearray(b'NEX Template V2')) # test GET requests - req = self._gen_get_rpc_req("invoke", params=[contract_hash, jsn]) - mock_req = mock_get_request(req) - res = json.loads(self.app.home(mock_req)) + url = self._gen_get_rpc_req("invoke", params=[contract_hash, jsn]) + res = json.loads(self.do_test_get(url)) self.assertEqual(res['result']['state'], VMStateStr(VMState.HALT)) self.assertEqual(res['result']['gas_consumed'], '0.128') results = [] @@ -143,8 +127,7 @@ def test_invoke_2(self): } ] req = self._gen_post_rpc_req("invoke", params=[contract_hash, jsn]) - mock_req = 
mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['state'], VMStateStr(VMState.HALT)) results = [] for p in res['result']['stack']: @@ -156,8 +139,7 @@ def test_invoke_2(self): def test_invoke_3(self): contract_hash = 'b9fbcff6e50fd381160b822207231233dd3c56c2' req = self._gen_post_rpc_req("invokefunction", params=[contract_hash, 'symbol']) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['state'], VMStateStr(VMState.HALT)) results = [] for p in res['result']['stack']: @@ -172,8 +154,7 @@ def test_invoke_4(self): 'value': bytearray(b'#\xba\'\x03\xc52c\xe8\xd6\xe5"\xdc2 39\xdc\xd8\xee\xe9').hex()}] req = self._gen_post_rpc_req("invokefunction", params=[contract_hash, 'balanceOf', params]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['state'], VMStateStr(VMState.HALT)) results = [] for p in res['result']['stack']: @@ -186,8 +167,7 @@ def test_invoke_5(self): test_script = "00046e616d6567c2563cdd3312230722820b1681d30fe5f6cffbb9000673796d626f6c67c2563cdd3312230722820b1681d30fe5f6cffbb90008646563696d616c7367c2563cdd3312230722820b1681d30fe5f6cffbb9" req = self._gen_post_rpc_req("invokescript", params=[test_script]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['state'], VMStateStr(VMState.HALT)) results = [] @@ -202,15 +182,13 @@ def test_invoke_5(self): def test_bad_invoke_script(self): test_script = '0zzzzzzef3e30b007cd98d67d7' req = self._gen_post_rpc_req("invokescript", params=[test_script]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertIn('Non-hexadecimal digit found', res['error']['message']) def test_bad_invoke_script_2(self): test_script = '00046e616d656754a64cac1b103e662933ef3e30b007cd98d67d7000673796d626f6c6754a64cac1b1073e662933ef3e30b007cd98d67d70008646563696d616c736754a64cac1b1073e662933ef3e30b007cd98d67d7' req = self._gen_post_rpc_req("invokescript", params=[test_script]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertIn('Odd-length string', res['error']['message']) diff --git a/neo/api/JSONRPC/test_json_rpc_api.py b/neo/api/JSONRPC/test_json_rpc_api.py index 4ea87e9ed..cd782c78c 100644 --- a/neo/api/JSONRPC/test_json_rpc_api.py +++ b/neo/api/JSONRPC/test_json_rpc_api.py @@ -3,191 +3,197 @@ $ python -m unittest neo.api.JSONRPC.test_json_rpc_api """ -import json import binascii +import json import os import shutil from tempfile import mkdtemp -from klein.test.test_resource import requestMock -from twisted.web import server -from twisted.web.test.test_web import DummyChannel +from unittest import SkipTest + +from aiohttp.test_utils import AioHTTPTestCase +from mock import patch from neo import __version__ -from neo.api.JSONRPC.JsonRpcApi import JsonRpcApi -from neo.Utils.BlockchainFixtureTestCase import 
BlockchainFixtureTestCase -from neo.Implementations.Wallets.peewee.UserWallet import UserWallet -from neo.Wallets.utils import to_aes_key -from neo.IO.Helper import Helper -from neo.Core.UInt256 import UInt256 from neo.Blockchain import GetBlockchain -from neo.Network.NodeLeader import NodeLeader -from neo.Network.NeoNode import NeoNode -from copy import deepcopy from neo.IO.Helper import Helper +from neo.Implementations.Wallets.peewee.UserWallet import UserWallet +from neo.Network.node import NeoNode from neo.Settings import ROOT_INSTALL_PATH, settings +from neo.Utils.BlockchainFixtureTestCase import BlockchainFixtureTestCase from neo.Utils.WalletFixtureTestCase import WalletFixtureTestCase -from mock import patch - - -def mock_post_request(body): - return requestMock(path=b'/', method=b"POST", body=body) - +from neo.Wallets.utils import to_aes_key +from neo.api.JSONRPC.JsonRpcApi import JsonRpcApi +from neo.Network.nodemanager import NodeManager -def mock_get_request(path, method=b"GET"): - request = server.Request(DummyChannel(), False) - request.uri = path - request.method = method - request.clientproto = b'HTTP/1.1' - return request +class JsonRpcApiTestCase(BlockchainFixtureTestCase, AioHTTPTestCase): -class JsonRpcApiTestCase(BlockchainFixtureTestCase): - app = None # type:JsonRpcApi + def __init__(self, *args, **kwargs): + super(JsonRpcApiTestCase, self).__init__(*args, **kwargs) + self.api_server = JsonRpcApi() @classmethod def leveldb_testpath(cls): return os.path.join(settings.DATA_DIR_PATH, 'fixtures/test_chain') - def setUp(self): - self.app = JsonRpcApi(20332) + async def get_application(self): + """ + Override the get_app method to return your application. + """ + + return self.api_server.app + + def do_test_get(self, url, data=None): + async def test_get_route(url, data=None): + resp = await self.client.get(url, data=data) + text = await resp.text() + return text + + return self.loop.run_until_complete(test_get_route(url, data)) + + def do_test_post(self, url, data=None, json=None): + if data is not None and json is not None: + raise ValueError("cannot specify `data` and `json` at the same time") + + async def test_get_route(url, data=None, json=None): + if data: + resp = await self.client.post(url, data=data) + else: + resp = await self.client.post(url, json=json) + + text = await resp.text() + return text + + return self.loop.run_until_complete(test_get_route(url, data, json)) + + def _gen_post_rpc_req(self, method, params=None, request_id="2"): + ret = { + "jsonrpc": "2.0", + "id": request_id, + "method": method + } + if params: + ret["params"] = params + return ret + + def _gen_get_rpc_req(self, method, params=None, request="2"): + ret = "/?jsonrpc=2.0&id=%s&method=%s&params=[]" % (request, method) + if params: + ret = "/?jsonrpc=2.0&id=%s&method=%s&params=%s" % (request, method, params) + return ret + @SkipTest def test_HTTP_OPTIONS_request(self): - mock_req = mock_get_request(b'/?test', b"OPTIONS") - res = json.loads(self.app.home(mock_req)) + # see constructor of JsonRPC api why we're skipping. 
CORS related + async def test_get_route(): + resp = await self.client.options("/") + text = await resp.text() + return json.loads(text) + + res = self.loop.run_until_complete(test_get_route()) self.assertTrue("GET" in res['supported HTTP methods']) self.assertTrue("POST" in res['supported HTTP methods']) self.assertTrue("default" in res['JSON-RPC server type']) + @SkipTest def test_invalid_request_method(self): # test HEAD method - mock_req = mock_get_request(b'/?test', b"HEAD") - res = json.loads(self.app.home(mock_req)) + async def test_get_route(): + resp = await self.client.head("/?test") + text = await resp.text() + return json.loads(text) + + res = self.loop.run_until_complete(test_get_route()) + self.assertEqual(res["error"]["code"], -32600) self.assertEqual(res["error"]["message"], 'HEAD is not a supported HTTP method') def test_invalid_json_payload(self): # test POST requests - mock_req = mock_post_request(b"{ invalid") - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", b"{ invalid")) self.assertEqual(res["error"]["code"], -32700) - mock_req = mock_post_request(json.dumps({"some": "stuff"}).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json.dumps({"some": "stuff"}).encode("utf-8"))) self.assertEqual(res["error"]["code"], -32600) # test GET requests - mock_req = mock_get_request(b"/") # equivalent to "/" - res = json.loads(self.app.home(mock_req)) - self.assertEqual(res["error"]["code"], -32700) - - mock_req = mock_get_request(b"/?%20invalid") # equivalent to "/? invalid" - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_get("/")) self.assertEqual(res["error"]["code"], -32600) - mock_req = mock_get_request(b"/?some=stuff") - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_get("/?%20invalid")) # equivalent to "/? 
invalid" self.assertEqual(res["error"]["code"], -32600) - def _gen_post_rpc_req(self, method, params=None, request_id="2"): - ret = { - "jsonrpc": "2.0", - "id": request_id, - "method": method - } - if params: - ret["params"] = params - return ret - - def _gen_get_rpc_req(self, method, params=None, request="2"): - ret = "/?jsonrpc=2.0&id=%s&method=%s¶ms=[]" % (request, method) - if params: - ret = "/?jsonrpc=2.0&id=%s&method=%s¶ms=%s" % (request, method, params) - return ret.encode('utf-8') + res = json.loads(self.do_test_get("/?some=stuff")) + self.assertEqual(res["error"]["code"], -32600) def test_initial_setup(self): self.assertTrue(GetBlockchain().GetBlock(0).Hash.To0xString(), '0x996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099') - def test_GET_request_bad_params(self): - req = "/?jsonrpc=2.0&method=getblockcount¶m=[]&id=2" # "params" is misspelled - mock_req = mock_get_request(req) - res = json.loads(self.app.home(mock_req)) - - error = res.get('error', {}) - self.assertEqual(error.get('code', None), -32602) - self.assertEqual(error.get('message', None), "Invalid params") - def test_missing_fields(self): # test POST requests req = self._gen_post_rpc_req("foo") del req["jsonrpc"] - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - self.assertEqual(res["error"]["code"], -32600) - self.assertEqual(res["error"]["message"], "Invalid value for 'jsonrpc'") + res = json.loads(self.do_test_post("/", data=req)) + self.assertEqual(res["error"]["code"], -32700) + self.assertEqual(res["error"]["message"], "Parse error") req = self._gen_post_rpc_req("foo") del req["id"] - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - self.assertEqual(res["error"]["code"], -32600) - self.assertEqual(res["error"]["message"], "Field 'id' is missing") + res = json.loads(self.do_test_post("/", data=req)) + self.assertEqual(res["error"]["code"], -32700) + self.assertEqual(res["error"]["message"], "Parse error") req = self._gen_post_rpc_req("foo") del req["method"] - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - self.assertEqual(res["error"]["code"], -32600) - self.assertEqual(res["error"]["message"], "Field 'method' is missing") + res = json.loads(self.do_test_post("/", data=req)) + self.assertEqual(res["error"]["code"], -32700) + self.assertEqual(res["error"]["message"], "Parse error") # test GET requests - mock_req = mock_get_request(b"/?method=foo&id=2") - res = json.loads(self.app.home(mock_req)) + url = "/?method=foo&id=2" + res = json.loads(self.do_test_get(url, data=req)) self.assertEqual(res["error"]["code"], -32600) self.assertEqual(res["error"]["message"], "Invalid value for 'jsonrpc'") - mock_req = mock_get_request(b"/?jsonrpc=2.0&method=foo") - res = json.loads(self.app.home(mock_req)) + url = "/?jsonrpc=2.0&method=foo" + res = json.loads(self.do_test_get(url, data=req)) self.assertEqual(res["error"]["code"], -32600) self.assertEqual(res["error"]["message"], "Field 'id' is missing") - mock_req = mock_get_request(b"/?jsonrpc=2.0&id=2") - res = json.loads(self.app.home(mock_req)) + url = "/?jsonrpc=2.0&id=2" + res = json.loads(self.do_test_get(url, data=req)) self.assertEqual(res["error"]["code"], -32600) self.assertEqual(res["error"]["message"], "Field 'method' is missing") def test_invalid_method(self): # test POST requests req = self._gen_post_rpc_req("invalid", request_id="42") - mock_req = 
mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res["id"], "42") self.assertEqual(res["error"]["code"], -32601) self.assertEqual(res["error"]["message"], "Method not found") # test GET requests - req = self._gen_get_rpc_req("invalid") - mock_req = mock_get_request(req) - res = json.loads(self.app.home(mock_req)) + url = self._gen_get_rpc_req("invalid") + res = json.loads(self.do_test_get(url)) self.assertEqual(res["error"]["code"], -32601) self.assertEqual(res["error"]["message"], "Method not found") def test_getblockcount(self): # test POST requests req = self._gen_post_rpc_req("getblockcount") - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(GetBlockchain().Height + 1, res["result"]) # test GET requests ...next we will test a complex method; see test_sendmany_complex - req = self._gen_get_rpc_req("getblockcount") - mock_req = mock_get_request(req) - res = json.loads(self.app.home(mock_req)) + url = self._gen_get_rpc_req("getblockcount") + res = json.loads(self.do_test_get(url)) self.assertEqual(GetBlockchain().Height + 1, res["result"]) def test_getblockhash(self): req = self._gen_post_rpc_req("getblockhash", params=[2]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) # taken from neoscan expected_blockhash = '0x049db9f55ac45201c128d1a40d0ef9d4bdc58db97d47d985ce8d66511a1ef9eb' @@ -195,16 +201,14 @@ def test_getblockhash(self): def test_getblockhash_failure(self): req = self._gen_post_rpc_req("getblockhash", params=[-1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(-100, res["error"]["code"]) self.assertEqual("Invalid Height", res["error"]["message"]) def test_account_state(self): addr_str = 'AK2nJJpJr6o664CWJKi1QRXjqeic2zRp8y' req = self._gen_post_rpc_req("getaccountstate", params=[addr_str]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['balances'][0]['value'], '99989900') self.assertEqual(res['result']['balances'][0]['asset'], '0xc56f33fc6ecfcd0c225c4ab356fee59390af8560be0e930faebe74a6daff7c9b'), self.assertEqual(res['result']['address'], addr_str) @@ -212,16 +216,14 @@ def test_account_state(self): def test_account_state_not_existing_yet(self): addr_str = 'AHozf8x8GmyLnNv8ikQcPKgRHQTbFi46u2' req = self._gen_post_rpc_req("getaccountstate", params=[addr_str]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['balances'], []) self.assertEqual(res['result']['address'], addr_str) def test_account_state_failure(self): addr_str = 'AK2nJJpJr6o664CWJKi1QRXjqeic2zRp81' req = self._gen_post_rpc_req("getaccountstate", params=[addr_str]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertEqual(-2146233033, res['error']['code']) self.assertEqual('One of the identified items was in an 
invalid format.', res['error']['message']) @@ -229,8 +231,7 @@ def test_account_state_failure(self): def test_get_asset_state_hash(self): asset_str = '602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7' req = self._gen_post_rpc_req("getassetstate", params=[asset_str]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['assetId'], '0x%s' % asset_str) self.assertEqual(res['result']['admin'], 'AWKECj9RD8rS8RPcpCgYVjk1DeYyHwxZm3') self.assertEqual(res['result']['available'], 0) @@ -238,8 +239,7 @@ def test_get_asset_state_hash(self): def test_get_asset_state_neo(self): asset_str = 'neo' req = self._gen_post_rpc_req("getassetstate", params=[asset_str]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['assetId'], '0x%s' % str(GetBlockchain().SystemShare().Hash)) self.assertEqual(res['result']['admin'], 'Abf2qMs1pzQb8kYk9RuxtUb9jtRKJVuBJt') self.assertEqual(res['result']['available'], 10000000000000000) @@ -247,8 +247,7 @@ def test_get_asset_state_neo(self): def test_get_asset_state_gas(self): asset_str = 'GAS' req = self._gen_post_rpc_req("getassetstate", params=[asset_str]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['assetId'], '0x%s' % str(GetBlockchain().SystemCoin().Hash)) self.assertEqual(res['result']['amount'], 10000000000000000) self.assertEqual(res['result']['admin'], 'AWKECj9RD8rS8RPcpCgYVjk1DeYyHwxZm3') @@ -256,47 +255,35 @@ def test_get_asset_state_gas(self): def test_get_asset_state_0x(self): asset_str = '0x602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282de7' req = self._gen_post_rpc_req("getassetstate", params=[asset_str]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['assetId'], asset_str) def test_bad_asset_state(self): asset_str = '602c79718b16e442de58778e148d0b1084e3b2dffd5de6b7b16cee7969282dee' req = self._gen_post_rpc_req("getassetstate", params=[asset_str]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertEqual(res['error']['message'], 'Unknown asset') def test_get_bestblockhash(self): req = self._gen_post_rpc_req("getbestblockhash", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result'], '0x0c9f39eddd425aba7c27543a90768093a90c76a35090ef9b413027927e887811') def test_get_connectioncount(self): # make sure we have a predictable state - NodeLeader.Reset() - leader = NodeLeader.Instance() - # old_leader = deepcopy(leader) - fake_obj = object() - leader.Peers = [fake_obj, fake_obj] - leader.KNOWN_ADDRS = [fake_obj, fake_obj] + nodemgr = self.api_server.nodemgr + nodemgr.reset_for_test() + nodemgr.nodes = [object(), object()] req = self._gen_post_rpc_req("getconnectioncount", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = 
json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result'], 2) - - # restore whatever state the instance was in - # NodeLeader._LEAD = old_leader + nodemgr.reset_for_test() def test_get_block_int(self): req = self._gen_post_rpc_req("getblock", params=[10, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['index'], 10) self.assertEqual(res['result']['hash'], '0xd69e7a1f62225a35fed91ca578f33447d93fa0fd2b2f662b957e19c38c1dab1e') self.assertEqual(res['result']['confirmations'], GetBlockchain().Height - 10 + 1) @@ -304,50 +291,42 @@ def test_get_block_int(self): def test_get_block_hash(self): req = self._gen_post_rpc_req("getblock", params=['2b1c78633dae7ab81f64362e0828153079a17b018d779d0406491f84c27b086f', 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['index'], 11) self.assertEqual(res['result']['confirmations'], GetBlockchain().Height - 11 + 1) self.assertEqual(res['result']['previousblockhash'], '0xd69e7a1f62225a35fed91ca578f33447d93fa0fd2b2f662b957e19c38c1dab1e') def test_get_block_hash_0x(self): req = self._gen_post_rpc_req("getblock", params=['0x2b1c78633dae7ab81f64362e0828153079a17b018d779d0406491f84c27b086f', 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['index'], 11) def test_get_block_hash_failure(self): req = self._gen_post_rpc_req("getblock", params=['aad34f68cb7a04d625ae095fa509479ec7dcb4dc87ecd865ab059d0f8a42decf', 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertEqual(res['error']['message'], 'Unknown block') def test_get_block_sysfee(self): req = self._gen_post_rpc_req("getblocksysfee", params=[9479]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result'], 1560) # test negative block req = self._gen_post_rpc_req("getblocksysfee", params=[-1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertEqual(res['error']['message'], 'Invalid Height') # test block exceeding max block height req = self._gen_post_rpc_req("getblocksysfee", params=[3000000000]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertEqual(res['error']['message'], 'Invalid Height') def test_block_non_verbose(self): req = self._gen_post_rpc_req("getblock", params=[2003, 0]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertIsNotNone(res['result']) # we should be able to instantiate a matching block with the result @@ -359,8 +338,7 @@ def test_block_non_verbose(self): def test_get_contract_state(self): 
contract_hash = "b9fbcff6e50fd381160b822207231233dd3c56c2" req = self._gen_post_rpc_req("getcontractstate", params=[contract_hash]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['code_version'], '') self.assertEqual(res['result']['properties']['storage'], True) self.assertEqual(res['result']['hash'], '0xb9fbcff6e50fd381160b822207231233dd3c56c2') @@ -370,72 +348,68 @@ def test_get_contract_state(self): def test_get_contract_state_0x(self): contract_hash = "0xb9fbcff6e50fd381160b822207231233dd3c56c2" req = self._gen_post_rpc_req("getcontractstate", params=[contract_hash]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['code_version'], '') def test_get_contract_state_not_found(self): contract_hash = '0xb9fbcff6e50fd381160b822207231233dd3c56c1' req = self._gen_post_rpc_req("getcontractstate", params=[contract_hash]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertEqual(res['error']['message'], 'Unknown contract') def test_get_raw_mempool(self): - # TODO: currently returns empty list. test with list would be great + nodemgr = self.api_server.nodemgr + nodemgr.reset_for_test() + + raw_tx = b'd100644011111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111081234567890abcdef0415cd5b0769cc4ee2f1c9f4e0782756dabf246d0a4fe60a035400000000' + tx = Helper.AsSerializableWithType(binascii.unhexlify(raw_tx), 'neo.Core.TX.InvocationTransaction.InvocationTransaction') + nodemgr.mempool.add_transaction(tx) + req = self._gen_post_rpc_req("getrawmempool", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) + mempool = res['result'] - # when running only these tests, mempool is empty. when running all tests, there are a - # number of entries - if len(mempool) > 0: - for entry in mempool: - self.assertEqual(entry[0:2], "0x") - self.assertEqual(len(entry), 66) + self.assertEqual(1, len(mempool)) + self.assertEqual(tx.Hash.To0xString(), mempool[0]) + nodemgr.reset_for_test() def test_get_version(self): - # TODO: what's the nonce? 
on testnet live server response it's always 771199013 + nodemgr = NodeManager() req = self._gen_post_rpc_req("getversion", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res["result"]["port"], 20332) + self.assertEqual(res["result"]["nonce"], nodemgr.id) self.assertEqual(res["result"]["useragent"], "/NEO-PYTHON:%s/" % __version__) def test_validate_address(self): # example from docs.neo.org req = self._gen_post_rpc_req("validateaddress", params=["AQVh2pG732YvtNaxEGkQUei3YA4cvo7d2i"]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue(res["result"]["isvalid"]) # example from docs.neo.org req = self._gen_post_rpc_req("validateaddress", params=["152f1muMCNa7goXYhYAQC61hxEgGacmncB"]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertFalse(res["result"]["isvalid"]) # catch completely invalid argument req = self._gen_post_rpc_req("validateaddress", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertEqual('Missing argument', res['error']['message']) # catch completely invalid argument req = self._gen_post_rpc_req("validateaddress", params=[""]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertEqual('Missing argument', res['error']['message']) def test_getrawtx_1(self): txid = 'f999c36145a41306c846ea80290416143e8e856559818065be3f4e143c60e43a' req = self._gen_post_rpc_req("getrawtransaction", params=[txid, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req))['result'] + res = json.loads(self.do_test_post("/", json=req))['result'] self.assertEqual(res['blockhash'], '0x6088bf9d3b55c67184f60b00d2e380228f713b4028b24c1719796dcd2006e417') self.assertEqual(res['txid'], "0x%s" % txid) self.assertEqual(res['blocktime'], 1533756500) @@ -444,16 +418,14 @@ def test_getrawtx_1(self): def test_getrawtx_2(self): txid = 'f999c36145a41306c846ea80290416143e8e856559818065be3f4e143c60e43a' req = self._gen_post_rpc_req("getrawtransaction", params=[txid, 0]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req))['result'] + res = json.loads(self.do_test_post("/", json=req))['result'] expected = '8000012023ba2703c53263e8d6e522dc32203339dcd8eee901ff6a846c115ef1fb88664b00aa67f2c95e9405286db1b56c9120c27c698490530000029b7cffdaa674beae0f930ebe6085af9093e5fe56b34a5c220ccdcf6efc336fc50010a5d4e8000000affb37f5fdb9c6fec48d9f0eee85af82950f9b4a9b7cffdaa674beae0f930ebe6085af9093e5fe56b34a5c220ccdcf6efc336fc500f01b9b0986230023ba2703c53263e8d6e522dc32203339dcd8eee9014140a88bd1fcfba334b06da0ce1a679f80711895dade50352074e79e438e142dc95528d04a00c579398cb96c7301428669a09286ae790459e05e907c61ab8a1191c62321031a6c6fbbdf02ca351745fa86b9ba5a9452d785ac4f7fc2b7548ca2a46c4fcf4aac' self.assertEqual(res, expected) def test_getrawtx_3(self): txid = 'f999c36145a41306c846ea80290416143e8e856559818065be3f4e143c60e43b' req = 
self._gen_post_rpc_req("getrawtransaction", params=[txid, 0]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertEqual(res['error']['message'], 'Unknown Transaction') @@ -461,8 +433,7 @@ def test_get_storage_item(self): contract_hash = 'b9fbcff6e50fd381160b822207231233dd3c56c2' storage_key = binascii.hexlify(b'in_circulation').decode('utf-8') req = self._gen_post_rpc_req("getstorage", params=[contract_hash, storage_key]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result'], '00a031a95fe300') actual_val = int.from_bytes(binascii.unhexlify(res['result'].encode('utf-8')), 'little') self.assertEqual(actual_val, 250000000000000) @@ -471,46 +442,36 @@ def test_get_storage_item2(self): contract_hash = '90ea0b9b8716cf0ceca5b24f6256adf204f444d9' storage_key = binascii.hexlify(b'in_circulation').decode('utf-8') req = self._gen_post_rpc_req("getstorage", params=[contract_hash, storage_key]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result'], '00c06e31d91001') def test_get_storage_item_key_not_found(self): contract_hash = 'b9fbcff6e50fd381160b822207231233dd3c56c1' storage_key = binascii.hexlify(b'blah').decode('utf-8') req = self._gen_post_rpc_req("getstorage", params=[contract_hash, storage_key]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result'], None) def test_get_storage_item_contract_not_found(self): contract_hash = 'b9fbcff6e50fd381160b822207231233dd3c56c1' storage_key = binascii.hexlify(b'blah').decode('utf-8') req = self._gen_post_rpc_req("getstorage", params=[contract_hash, storage_key]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result'], None) def test_get_storage_item_bad_contract_hash(self): contract_hash = 'b9fbcff6e50f01160b822207231233dd3c56c1' storage_key = binascii.hexlify(b'blah').decode('utf-8') req = self._gen_post_rpc_req("getstorage", params=[contract_hash, storage_key]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertIn('Invalid UInt', res['error']['message']) - def test_get_unspents(self): - u = UInt256.ParseString('f999c36145a41306c846ea80290416143e8e856559818065be3f4e143c60e43a') - unspents = GetBlockchain().GetAllUnspent(u) - self.assertEqual(len(unspents), 1) - def test_gettxout(self): txid = 'a2a37fd2ab7048d70d51eaa8af2815e0e542400329b05a34274771174180a7e8' output_index = 0 req = self._gen_post_rpc_req("gettxout", params=[txid, output_index]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) # will return `null` if not found self.assertEqual(None, res["result"]) @@ -519,8 +480,7 @@ def test_gettxout(self): txid = '42978cd563e9e95550fb51281d9071e27ec94bd42116836f0d0141d57a346b3e' output_index = 1 req = 
self._gen_post_rpc_req("gettxout", params=[txid, output_index]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) expected_asset = '0xc56f33fc6ecfcd0c225c4ab356fee59390af8560be0e930faebe74a6daff7c9b' expected_value = "99989900" @@ -535,110 +495,105 @@ def test_gettxout(self): txid = 'f999c36145a41306c846ea80290416143e8e856559818065be3f4e143c60e43a' output_index = 0 req = self._gen_post_rpc_req("gettxout", params=[txid, output_index]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) expected_value = "10000" self.assertEqual(output_index, res["result"]["n"]) self.assertEqual(expected_value, res["result"]["value"]) def test_send_raw_tx(self): - raw_tx = '8000000001e72d286979ee6cb1b7e65dfddfb2e384100b8d148e7758de42e4168b71792c6000ca9a3b0000000048033b58ef547cbf54c8ee2f72a42d5b603c00af' - req = self._gen_post_rpc_req("sendrawtransaction", params=[raw_tx]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - self.assertEqual(res['result'], True) + + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + nodemgr = self.api_server.nodemgr + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object(), object())] + + raw_tx = '8000000001e72d286979ee6cb1b7e65dfddfb2e384100b8d148e7758de42e4168b71792c6000ca9a3b0000000048033b58ef547cbf54c8ee2f72a42d5b603c00af' + req = self._gen_post_rpc_req("sendrawtransaction", params=[raw_tx]) + res = json.loads(self.do_test_post("/", json=req)) + self.assertEqual(res['result'], True) + nodemgr.reset_for_test() def test_send_raw_tx_bad(self): - raw_tx = '80000001b10ad9ec660bf343c0eb411f9e05b4fa4ad8abed31d4e4dc5bb6ae416af0c4de000002e72d286979ee6cb1b7e65dfddfb2e384100b8d148e7758de42e4168b71792c60c8db571300000000af12a8687b14948bc4a008128a550a63695bc1a5e72d286979ee6cb1b7e65dfddfb2e384100b8d148e7758de42e4168b71792c603808b44002000000eca8fcf94e7a2a7fc3fd54ae0ed3d34d52ec25900141404749ce868ed9588f604eeeb5c523db39fd57cd7f61d04393a1754c2d32f131d67e6b1ec561ac05012b7298eb5ff254487c76de0b2a0c4d097d17cec708c0a9802321025b5c8cdcb32f8e278e111a0bf58ebb463988024bb4e250aa4310b40252030b60ac' - req = self._gen_post_rpc_req("sendrawtransaction", params=[raw_tx]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - self.assertEqual(res['result'], False) + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + nodemgr = self.api_server.nodemgr + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object(), object())] - def test_send_raw_tx_bad_2(self): - raw_tx = '80000001b10ad9ec660bf343c0eb411f9e05b4fa4ad8abed31d4e4dc5bb6ae416af0c4de000002e72d286979ee6cbb7e65dfddfb2e384100b8d148e7758de42e4168b71792c60c8db571300000000af12a8687b14948bc4a008128a550a63695bc1a5e72d286979ee6cb1b7e65dfddfb2e384100b8d148e7758de42e4168b71792c603808b44002000000eca8fcf94e7a2a7fc3fd54ae0ed3d34d52ec25900141404749ce868ed9588f604eeeb5c523db39fd57cd7f61d04393a1754c2d32f131d67e6b1ec561ac05012b7298eb5ff254487c76de0b2a0c4d097d17cec708c0a9802321025b5c8cdcb32f8e278e111a0bf58ebb463988024bb4e250aa4310b40252030b60ac' - req = self._gen_post_rpc_req("sendrawtransaction", params=[raw_tx]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - self.assertTrue('error' in res) - 
self.assertEqual(res['error']['code'], -32603) + raw_tx = '80000001b10ad9ec660bf343c0eb411f9e05b4fa4ad8abed31d4e4dc5bb6ae416af0c4de000002e72d286979ee6cb1b7e65dfddfb2e384100b8d148e7758de42e4168b71792c60c8db571300000000af12a8687b14948bc4a008128a550a63695bc1a5e72d286979ee6cb1b7e65dfddfb2e384100b8d148e7758de42e4168b71792c603808b44002000000eca8fcf94e7a2a7fc3fd54ae0ed3d34d52ec25900141404749ce868ed9588f604eeeb5c523db39fd57cd7f61d04393a1754c2d32f131d67e6b1ec561ac05012b7298eb5ff254487c76de0b2a0c4d097d17cec708c0a9802321025b5c8cdcb32f8e278e111a0bf58ebb463988024bb4e250aa4310b40252030b60ac' + req = self._gen_post_rpc_req("sendrawtransaction", params=[raw_tx]) + res = json.loads(self.do_test_post("/", json=req)) + self.assertEqual(res['result'], False) + nodemgr.reset_for_test() + def test_send_raw_tx_bad_2(self): + with patch('neo.Network.node.NeoNode.relay', return_value=self.async_return(True)): + nodemgr = self.api_server.nodemgr + nodemgr.reset_for_test() + nodemgr.nodes = [NeoNode(object(), object())] + + raw_tx = '80000001b10ad9ec660bf343c0eb411f9e05b4fa4ad8abed31d4e4dc5bb6ae416af0c4de000002e72d286979ee6cbb7e65dfddfb2e384100b8d148e7758de42e4168b71792c60c8db571300000000af12a8687b14948bc4a008128a550a63695bc1a5e72d286979ee6cb1b7e65dfddfb2e384100b8d148e7758de42e4168b71792c603808b44002000000eca8fcf94e7a2a7fc3fd54ae0ed3d34d52ec25900141404749ce868ed9588f604eeeb5c523db39fd57cd7f61d04393a1754c2d32f131d67e6b1ec561ac05012b7298eb5ff254487c76de0b2a0c4d097d17cec708c0a9802321025b5c8cdcb32f8e278e111a0bf58ebb463988024bb4e250aa4310b40252030b60ac' + req = self._gen_post_rpc_req("sendrawtransaction", params=[raw_tx]) + res = json.loads(self.do_test_post("/", json=req)) + self.assertTrue('error' in res) + self.assertEqual(res['error']['code'], -32603) + nodemgr.reset_for_test() + + @SkipTest def test_gzip_compression(self): - req = self._gen_post_rpc_req("getblock", params=['307ed2cf8b8935dd38c534b10dceac55fcd0f60c68bf409627f6c155f8143b31', 1]) + # TODO: figure out how to properly validate gzip with aiohttp + # note it is applied using the @json decorator in neo/api/utils.py + req = self._gen_post_rpc_req("getblock", params=['0x2b1c78633dae7ab81f64362e0828153079a17b018d779d0406491f84c27b086f', 1]) body = json.dumps(req).encode("utf-8") - # first validate that we get a gzip response if we accept gzip encoding - mock_req = requestMock(path=b'/', method=b"POST", body=body, headers={'Accept-Encoding': ['deflate', 'gzip;q=1.0', '*;q=0.5']}) - res = self.app.home(mock_req) - - GZIP_MAGIC = b'\x1f\x8b' - self.assertIsInstance(res, bytes) - self.assertTrue(res.startswith(GZIP_MAGIC)) + async def test_get_route(url, data=None, headers=None): + resp = await self.client.post(url, json=data, headers=headers) + return resp - # then validate that we don't get a gzip response if we don't accept gzip encoding - mock_req = requestMock(path=b'/', method=b"POST", body=body, headers={}) - res = self.app.home(mock_req) - - self.assertIsInstance(res, str) + # first validate that we get a gzip response if we accept gzip encoding + resp = self.loop.run_until_complete(test_get_route("/", headers={'Accept-Encoding': "deflate, gzip;q=1.0, *;q=0.5"})) + self.assertEqual(83, resp.content_length) - try: - json.loads(res) - valid_json = True - except ValueError: - valid_json = False - self.assertTrue(valid_json) + resp = self.loop.run_until_complete(test_get_route("/", headers={'Accept-Encoding': ""})) + self.assertEqual(2283, resp.content_length) def test_getpeers(self): - # Given this is an isolated environment and there is no peers + # Given 
this is an isolated environment and there are no peers # let's simulate that at least some addresses are known - node = NodeLeader.Instance() - node.KNOWN_ADDRS = ["127.0.0.1:20333", "127.0.0.2:20334"] - node.DEAD_ADDRS = ["127.0.0.1:20335"] - test_node = NeoNode() - test_node.host = "127.0.0.1" - test_node.port = 20333 - node.Peers = [test_node] + nodemgr = self.api_server.nodemgr + nodemgr.reset_for_test() + node1 = NeoNode(object, object) + node1.address = "127.0.0.1:2222" + nodemgr.nodes.append(node1) + + nodemgr.known_addresses = ["127.0.0.1:20333", "127.0.0.2:20334"] + nodemgr.bad_addresses = ["127.0.0.1:20335"] req = self._gen_post_rpc_req("getpeers", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - - self.assertEqual(len(node.Peers), len(res['result']['connected'])) - print("unconnected:{}".format(len(res['result']['unconnected']))) - print("addrs:{} peers:{}".format(len(node.KNOWN_ADDRS), len(node.Peers))) - self.assertEqual(len(res['result']['unconnected']), - len(node.KNOWN_ADDRS) - len(node.Peers)) - self.assertEqual(len(res['result']['bad']), 1) - # To avoid messing up the next tests - node.Peers = [] - node.KNOWN_ADDRS = [] - node.DEAD_ADDRS = [] + res = json.loads(self.do_test_post("/", json=req)) + + self.assertEqual(1, len(res['result']['connected'])) + self.assertEqual(2, len(res['result']['unconnected'])) + self.assertEqual(1, len(res['result']['bad'])) def test_getwalletheight_no_wallet(self): req = self._gen_post_rpc_req("getwalletheight", params=["some id here"]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -400) self.assertEqual(error.get('message', None), "Access denied.") def test_getwalletheight(self): - self.app.wallet = UserWallet.Open(os.path.join(ROOT_INSTALL_PATH, "neo/data/neo-privnet.sample.wallet"), to_aes_key("coz")) + self.api_server.wallet = UserWallet.Open(os.path.join(ROOT_INSTALL_PATH, "neo/data/neo-privnet.sample.wallet"), to_aes_key("coz")) req = self._gen_post_rpc_req("getwalletheight", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(1, res.get('result')) def test_getbalance_no_wallet(self): req = self._gen_post_rpc_req("getbalance", params=["some id here"]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -400) @@ -649,53 +604,48 @@ def test_getbalance_neo_with_wallet(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) neo_id = "c56f33fc6ecfcd0c225c4ab356fee59390af8560be0e930faebe74a6daff7c9b" req = self._gen_post_rpc_req("getbalance", params=[neo_id]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) self.assertIn('Balance', res.get('result').keys()) self.assertEqual(res['result']['Balance'], "150") self.assertIn('Confirmed', res.get('result').keys()) 
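# A minimal sketch, assuming the rewritten JSON-RPC test class is built on
# aiohttp's AioHTTPTestCase like the REST tests further down in this patch, of
# the two helpers these tests rely on but which are not shown in this hunk:
# `do_test_post` (mirroring the `do_test_get` helper added to test_rest_api.py)
# and `async_return` (wrapping a plain value in a Future so it can serve as the
# return_value of a patched coroutine such as NeoNode.relay). Names and
# signatures here are assumptions, not the project's actual implementation.
import asyncio

from aiohttp.test_utils import AioHTTPTestCase


class JsonRpcTestHelpersSketch(AioHTTPTestCase):
    # hypothetical stand-in for the real test case class, which would also
    # override get_application() to return the RPC server's aiohttp app

    def do_test_post(self, url, data=None, json=None, headers=None):
        async def post_route():
            resp = await self.client.post(url, data=data, json=json, headers=headers)
            return await resp.text()

        return self.loop.run_until_complete(post_route())

    @staticmethod
    def async_return(value):
        # a Future that is already resolved, usable as a coroutine's mock return value
        future = asyncio.Future()
        future.set_result(value)
        return future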
self.assertEqual(res['result']['Confirmed'], "50.0") - self.app.wallet.Close() - self.app.wallet = None - os.remove(WalletFixtureTestCase.wallet_1_dest()) + self.api_server.wallet.Close() + self.api_server.wallet = None + os.remove(test_wallet_path) def test_getbalance_token_with_wallet(self): test_wallet_path = shutil.copyfile( WalletFixtureTestCase.wallet_2_path(), WalletFixtureTestCase.wallet_2_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_2_pass()) ) fake_token_id = "NXT4" req = self._gen_post_rpc_req("getbalance", params=[fake_token_id]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertIn('Balance', res.get('result').keys()) self.assertEqual(res['result']['Balance'], "1000") self.assertNotIn('Confirmed', res.get('result').keys()) - self.app.wallet.Close() - self.app.wallet = None - os.remove(WalletFixtureTestCase.wallet_2_dest()) + self.api_server.wallet.Close() + self.api_server.wallet = None + os.remove(test_wallet_path) def test_listaddress_no_wallet(self): req = self._gen_post_rpc_req("listaddress", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -400) @@ -703,26 +653,24 @@ def test_listaddress_no_wallet(self): def test_listaddress_with_wallet(self): test_wallet_path = os.path.join(mkdtemp(), "listaddress.db3") - self.app.wallet = UserWallet.Create( + self.api_server.wallet = UserWallet.Create( test_wallet_path, to_aes_key('awesomepassword') ) req = self._gen_post_rpc_req("listaddress", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) results = res.get('result', []) self.assertGreater(len(results), 0) self.assertIn(results[0].get('address', None), - self.app.wallet.Addresses) - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Addresses) + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(test_wallet_path) def test_getnewaddress_no_wallet(self): req = self._gen_post_rpc_req("getnewaddress", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) @@ -731,23 +679,22 @@ def test_getnewaddress_no_wallet(self): def test_getnewaddress_with_wallet(self): test_wallet_path = os.path.join(mkdtemp(), "getnewaddress.db3") - self.app.wallet = UserWallet.Create( + self.api_server.wallet = UserWallet.Create( test_wallet_path, to_aes_key('awesomepassword') ) - old_addrs = self.app.wallet.Addresses + old_addrs = self.api_server.wallet.Addresses req = self._gen_post_rpc_req("getnewaddress", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) result = res.get('result') self.assertNotIn(result, old_addrs) - self.assertIn(result, self.app.wallet.Addresses) + self.assertIn(result, self.api_server.wallet.Addresses) - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(test_wallet_path) def 
test_valid_multirequest(self): @@ -756,8 +703,7 @@ def test_valid_multirequest(self): verbose_block_request = {"jsonrpc": "2.0", "method": "getblock", "params": [1, 1], "id": 2} multi_request = json.dumps([raw_block_request, verbose_block_request]) - mock_req = mock_post_request(multi_request.encode()) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", data=multi_request)) self.assertEqual(type(res), list) self.assertEqual(len(res), 2) @@ -767,10 +713,8 @@ def test_valid_multirequest(self): self.assertEqual(res[1]['result']['hash'], expected_verbose_hash) # test GET requests ...should fail - raw_request = b"/?[jsonrpc=2.0&method=getblock¶ms=[1]&id=1,jsonrpc=2.0&method=getblock¶ms=[1,1]&id=2]" - - mock_req = mock_get_request(raw_request) - res = json.loads(self.app.home(mock_req)) + raw_request = "/?[jsonrpc=2.0&method=getblock¶ms=[1]&id=1,jsonrpc=2.0&method=getblock¶ms=[1,1]&id=2]" + res = json.loads(self.do_test_get(raw_request)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32600) @@ -785,8 +729,7 @@ def test_multirequest_with_1_invalid_request(self): verbose_block_request = {"jsonrpc": "2.0", "method": "getblock", "params": [1, 1], "id": 2} multi_request = json.dumps([raw_block_request, verbose_block_request]) - mock_req = mock_post_request(multi_request.encode()) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", data=multi_request)) self.assertEqual(type(res), list) self.assertEqual(len(res), 2) @@ -802,9 +745,7 @@ def test_multirequest_with_1_invalid_request(self): def test_send_to_address_no_wallet(self): req = self._gen_post_rpc_req("sendtoaddress", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -400) @@ -812,22 +753,20 @@ def test_send_to_address_no_wallet(self): def test_send_to_address_wrong_arguments(self): test_wallet_path = os.path.join(mkdtemp(), "sendtoaddress.db3") - self.app.wallet = UserWallet.Create( + self.api_server.wallet = UserWallet.Create( test_wallet_path, to_aes_key('awesomepassword') ) req = self._gen_post_rpc_req("sendtoaddress", params=["arg"]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(test_wallet_path) def test_send_to_address_simple(self): @@ -835,22 +774,21 @@ def test_send_to_address_simple(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' req = self._gen_post_rpc_req("sendtoaddress", params=['gas', address, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res.get('jsonrpc', None), '2.0') self.assertIn('txid', res.get('result', {}).keys()) self.assertIn('vin', res.get('result', {}).keys()) - self.app.wallet.Close() - 
self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_to_address_with_fee(self): @@ -858,22 +796,21 @@ def test_send_to_address_with_fee(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' req = self._gen_post_rpc_req("sendtoaddress", params=['neo', address, 1, 0.005]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res.get('jsonrpc', None), '2.0') self.assertIn('txid', res.get('result', {}).keys()) self.assertEqual(res['result']['net_fee'], "0.005") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_to_address_bad_assetid(self): @@ -881,22 +818,21 @@ def test_send_to_address_bad_assetid(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' req = self._gen_post_rpc_req("sendtoaddress", params=['ga', address, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_to_address_bad_address(self): @@ -904,22 +840,20 @@ def test_send_to_address_bad_address(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaX' # "AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaX" is too short causing ToScriptHash to fail req = self._gen_post_rpc_req("sendtoaddress", params=['gas', address, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_to_address_negative_amount(self): @@ -927,22 +861,21 @@ def test_send_to_address_negative_amount(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' req = self._gen_post_rpc_req("sendtoaddress", params=['gas', address, -1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - 
res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_to_address_zero_amount(self): @@ -950,22 +883,21 @@ def test_send_to_address_zero_amount(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' req = self._gen_post_rpc_req("sendtoaddress", params=['gas', address, 0]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_to_address_negative_fee(self): @@ -973,22 +905,21 @@ def test_send_to_address_negative_fee(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' req = self._gen_post_rpc_req("sendtoaddress", params=['gas', address, 1, -0.005]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_to_address_insufficient_funds(self): @@ -996,101 +927,96 @@ def test_send_to_address_insufficient_funds(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' req = self._gen_post_rpc_req("sendtoaddress", params=['gas', address, 51]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -300) self.assertEqual(error.get('message', None), "Insufficient funds") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_to_address_fails_to_sign_tx(self): - with patch('neo.api.JSONRPC.JsonRpcApi.Wallet.Sign', return_value='False'): + with patch('neo.Implementations.Wallets.peewee.UserWallet.UserWallet.Sign', return_value='False'): test_wallet_path = shutil.copyfile( WalletFixtureTestCase.wallet_1_path(), 
WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' req = self._gen_post_rpc_req("sendtoaddress", params=['gas', address, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) - + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res.get('jsonrpc', None), '2.0') self.assertIn('type', res.get('result', {}).keys()) self.assertIn('hex', res.get('result', {}).keys()) - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_from_no_wallet(self): req = self._gen_post_rpc_req("sendfrom", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -400) self.assertEqual(error.get('message', None), "Access denied.") def test_send_from_wrong_arguments(self): test_wallet_path = os.path.join(mkdtemp(), "sendfromaddress.db3") - self.app.wallet = UserWallet.Create( + self.api_server.wallet = UserWallet.Create( test_wallet_path, to_aes_key('awesomepassword') ) req = self._gen_post_rpc_req("sendfrom", params=["arg"]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(test_wallet_path) def test_send_from_simple(self): + self.api_server.nodemgr.reset_for_test() test_wallet_path = shutil.copyfile( WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address_to = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' address_from = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3' req = self._gen_post_rpc_req("sendfrom", params=['neo', address_from, address_to, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res.get('jsonrpc', None), '2.0') self.assertIn('txid', res.get('result', {}).keys()) self.assertIn('vin', res.get('result', {}).keys()) self.assertEqual(address_to, res['result']['vout'][0]['address']) self.assertEqual(address_from, res['result']['vout'][1]['address']) - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_from_complex(self): + self.api_server.nodemgr.reset_for_test() test_wallet_path = shutil.copyfile( WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1105,8 +1031,7 @@ def test_send_from_complex(self): address_from_gas_bal = address_from_gas['value'] req = 
self._gen_post_rpc_req("sendfrom", params=['gas', address_from, address_to, amount, net_fee, change_addr]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res.get('jsonrpc', None), '2.0') self.assertIn('txid', res.get('result', {}).keys()) @@ -1115,8 +1040,8 @@ def test_send_from_complex(self): self.assertEqual(float(address_from_gas_bal) - amount - net_fee, float(res['result']['vout'][1]['value'])) self.assertEqual(res['result']['net_fee'], "0.005") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_from_bad_assetid(self): @@ -1124,20 +1049,19 @@ def test_send_from_bad_assetid(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address_to = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' address_from = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3' req = self._gen_post_rpc_req("sendfrom", params=['nep', address_from, address_to, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_from_negative_amount(self): @@ -1145,20 +1069,19 @@ def test_send_from_negative_amount(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address_to = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' address_from = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3' req = self._gen_post_rpc_req("sendfrom", params=['neo', address_from, address_to, -1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_from_zero_amount(self): @@ -1166,20 +1089,19 @@ def test_send_from_zero_amount(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address_to = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' address_from = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3' req = self._gen_post_rpc_req("sendfrom", params=['neo', address_from, address_to, 0]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - 
self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_from_bad_from_addr(self): @@ -1187,20 +1109,19 @@ def test_send_from_bad_from_addr(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address_to = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' address_from = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc' # "AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc" is too short causing ToScriptHash to fail req = self._gen_post_rpc_req("sendfrom", params=['neo', address_from, address_to, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_from_bad_to_addr(self): @@ -1208,20 +1129,19 @@ def test_send_from_bad_to_addr(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address_to = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaX' # "AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaX" is too short causing ToScriptHash to fail address_from = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3' req = self._gen_post_rpc_req("sendfrom", params=['neo', address_from, address_to, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_from_negative_fee(self): @@ -1229,20 +1149,19 @@ def test_send_from_negative_fee(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address_to = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' address_from = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3' req = self._gen_post_rpc_req("sendfrom", params=['neo', address_from, address_to, 1, -0.005]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_from_bad_change_addr(self): @@ -1250,20 +1169,19 @@ def test_send_from_bad_change_addr(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = 
UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address_to = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' address_from = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3' req = self._gen_post_rpc_req("sendfrom", params=['neo', address_from, address_to, 1, .005, 'AGYaEi3W6ndHPUmW7T12FFfsbQ6DWymkE']) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_from_insufficient_funds(self): @@ -1271,61 +1189,58 @@ def test_send_from_insufficient_funds(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address_to = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' address_from = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3' req = self._gen_post_rpc_req("sendfrom", params=['neo', address_from, address_to, 51]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -300) self.assertEqual(error.get('message', None), "Insufficient funds") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_send_from_fails_to_sign_tx(self): - with patch('neo.api.JSONRPC.JsonRpcApi.Wallet.Sign', return_value='False'): + with patch('neo.Implementations.Wallets.peewee.UserWallet.UserWallet.Sign', return_value='False'): test_wallet_path = shutil.copyfile( WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) address_to = 'AXjaFSP23Jkbe6Pk9pPGT6NBDs1HVdqaXK' address_from = 'AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3' req = self._gen_post_rpc_req("sendfrom", params=['neo', address_from, address_to, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res.get('jsonrpc', None), '2.0') self.assertIn('type', res.get('result', {}).keys()) self.assertIn('hex', res.get('result', {}).keys()) - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_sendmany_no_wallet(self): req = self._gen_post_rpc_req("sendmany", params=[]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -400) self.assertEqual(error.get('message', None), "Access denied.") - def test_sendmany_complex(self): + def test_sendmany_complex_post(self): # test POST requests test_wallet_path = shutil.copyfile( WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - 
self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1337,8 +1252,7 @@ def test_sendmany_complex(self): "value": 1, "address": address_to}] req = self._gen_post_rpc_req("sendmany", params=[output, 1, "APRgMZHZubii29UXF9uFa6sohrsYupNAvx"]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res.get('jsonrpc', None), '2.0') self.assertIn('txid', res.get('result', {}).keys()) @@ -1352,8 +1266,8 @@ def test_sendmany_complex(self): transfers += 1 self.assertEqual(2, transfers) - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) # test GET requests @@ -1361,13 +1275,12 @@ def test_sendmany_complex(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) req = self._gen_get_rpc_req("sendmany", params=[output, 0.005, "APRgMZHZubii29UXF9uFa6sohrsYupNAvx"]) - mock_req = mock_get_request(req) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_get(req)) self.assertEqual(res.get('jsonrpc', None), '2.0') self.assertIn('txid', res.get('result', {}).keys()) @@ -1381,8 +1294,8 @@ def test_sendmany_complex(self): transfers += 1 self.assertEqual(2, transfers) - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_sendmany_min_params(self): @@ -1390,7 +1303,7 @@ def test_sendmany_min_params(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1402,30 +1315,28 @@ def test_sendmany_min_params(self): "value": 1, "address": address_to}] req = self._gen_post_rpc_req("sendmany", params=[output]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res.get('jsonrpc', None), '2.0') self.assertIn('txid', res.get('result', {}).keys()) self.assertIn('vin', res.get('result', {}).keys()) self.assertIn("AJQ6FoaSXDFzA6wLnyZ1nFN7SGSN2oNTc3", res['result']['vout'][2]['address']) - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_sendmany_not_list(self): test_wallet_path = os.path.join(mkdtemp(), "sendfromaddress.db3") - self.app.wallet = UserWallet.Create( + self.api_server.wallet = UserWallet.Create( test_wallet_path, to_aes_key('awesomepassword') ) req = self._gen_post_rpc_req("sendmany", params=["arg"]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(test_wallet_path) def 
test_sendmany_too_many_args(self): @@ -1433,7 +1344,7 @@ def test_sendmany_too_many_args(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1445,13 +1356,12 @@ def test_sendmany_too_many_args(self): "value": 1, "address": address_to}] req = self._gen_post_rpc_req("sendmany", params=[output, 1, "APRgMZHZubii29UXF9uFa6sohrsYupNAvx", "arg"]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_sendmany_bad_assetid(self): @@ -1459,7 +1369,7 @@ def test_sendmany_bad_assetid(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1471,13 +1381,12 @@ def test_sendmany_bad_assetid(self): "value": 1, "address": address_to}] req = self._gen_post_rpc_req("sendmany", params=[output]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_sendmany_bad_address(self): @@ -1485,7 +1394,7 @@ def test_sendmany_bad_address(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1497,13 +1406,12 @@ def test_sendmany_bad_address(self): "value": 1, "address": address_to}] req = self._gen_post_rpc_req("sendmany", params=[output]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_sendmany_negative_amount(self): @@ -1511,7 +1419,7 @@ def test_sendmany_negative_amount(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1523,13 +1431,12 @@ def test_sendmany_negative_amount(self): "value": -1, "address": address_to}] req = self._gen_post_rpc_req("sendmany", params=[output]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) 
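# A hedged sketch of the `_gen_post_rpc_req` helper used throughout these tests;
# its implementation is not part of this patch. The multirequest test above shows
# the JSON-RPC 2.0 envelope it is expected to produce, e.g.
# {"jsonrpc": "2.0", "method": "getblock", "params": [1, 1], "id": 2}.
def _gen_post_rpc_req(self, method, params=None, request_id=2):
    # assumed signature and default id; only the envelope shape is taken from the tests
    return {
        "jsonrpc": "2.0",
        "method": method,
        "params": params if params is not None else [],
        "id": request_id,
    }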
error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_sendmany_zero_amount(self): @@ -1537,7 +1444,7 @@ def test_sendmany_zero_amount(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1549,13 +1456,12 @@ def test_sendmany_zero_amount(self): "value": 0, "address": address_to}] req = self._gen_post_rpc_req("sendmany", params=[output]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_sendmany_negative_fee(self): @@ -1563,7 +1469,7 @@ def test_sendmany_negative_fee(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1575,13 +1481,12 @@ def test_sendmany_negative_fee(self): "value": 1, "address": address_to}] req = self._gen_post_rpc_req("sendmany", params=[output, -0.005, "APRgMZHZubii29UXF9uFa6sohrsYupNAvx"]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_sendmany_bad_change_address(self): @@ -1589,7 +1494,7 @@ def test_sendmany_bad_change_address(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1602,13 +1507,12 @@ def test_sendmany_bad_change_address(self): "value": 1, "address": address_to}] req = self._gen_post_rpc_req("sendmany", params=[output, 0.005, change_addr]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -32602) self.assertEqual(error.get('message', None), "Invalid params") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_sendmany_insufficient_funds(self): @@ -1616,7 +1520,7 @@ def test_sendmany_insufficient_funds(self): WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, 
to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1628,22 +1532,21 @@ def test_sendmany_insufficient_funds(self): "value": 1, "address": address_to}] req = self._gen_post_rpc_req("sendmany", params=[output]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) error = res.get('error', {}) self.assertEqual(error.get('code', None), -300) self.assertEqual(error.get('message', None), "Insufficient funds") - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_sendmany_fails_to_sign_tx(self): - with patch('neo.api.JSONRPC.JsonRpcApi.Wallet.Sign', return_value='False'): + with patch('neo.Implementations.Wallets.peewee.UserWallet.UserWallet.Sign', return_value='False'): test_wallet_path = shutil.copyfile( WalletFixtureTestCase.wallet_1_path(), WalletFixtureTestCase.wallet_1_dest() ) - self.app.wallet = UserWallet.Open( + self.api_server.wallet = UserWallet.Open( test_wallet_path, to_aes_key(WalletFixtureTestCase.wallet_1_pass()) ) @@ -1655,19 +1558,17 @@ def test_sendmany_fails_to_sign_tx(self): "value": 1, "address": address_to}] req = self._gen_post_rpc_req("sendmany", params=[output]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res.get('jsonrpc', None), '2.0') self.assertIn('type', res.get('result', {}).keys()) self.assertIn('hex', res.get('result', {}).keys()) - self.app.wallet.Close() - self.app.wallet = None + self.api_server.wallet.Close() + self.api_server.wallet = None os.remove(WalletFixtureTestCase.wallet_1_dest()) def test_getblockheader_int(self): req = self._gen_post_rpc_req("getblockheader", params=[10, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['index'], 10) self.assertEqual(res['result']['hash'], '0xd69e7a1f62225a35fed91ca578f33447d93fa0fd2b2f662b957e19c38c1dab1e') self.assertEqual(res['result']['confirmations'], GetBlockchain().Height - 10 + 1) @@ -1675,8 +1576,7 @@ def test_getblockheader_int(self): def test_getblockheader_hash(self): req = self._gen_post_rpc_req("getblockheader", params=['2b1c78633dae7ab81f64362e0828153079a17b018d779d0406491f84c27b086f', 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['index'], 11) self.assertEqual(res['result']['confirmations'], GetBlockchain().Height - 11 + 1) @@ -1684,21 +1584,18 @@ def test_getblockheader_hash(self): def test_getblockheader_hash_0x(self): req = self._gen_post_rpc_req("getblockheader", params=['0x2b1c78633dae7ab81f64362e0828153079a17b018d779d0406491f84c27b086f', 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(res['result']['index'], 11) def test_getblockheader_hash_failure(self): req = self._gen_post_rpc_req("getblockheader", params=[GetBlockchain().Height + 1, 1]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) 
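# For reference, the error codes asserted throughout this test module. The -32xxx
# codes follow the JSON-RPC 2.0 specification; -400 and -300 are project-specific.
# The messages are the ones checked by the assertions above (the -32600 and -32603
# labels are the standard spec names, inferred rather than asserted in the tests).
RPC_ERROR_CODES_SEEN_IN_TESTS = {
    -32600: "Invalid Request",   # e.g. a batched request sent via GET
    -32602: "Invalid params",    # bad address, asset id, amount or fee
    -32603: "Internal error",    # e.g. a raw transaction that fails to deserialize
    -400: "Access denied.",      # wallet-backed method called without an open wallet
    -300: "Insufficient funds",
}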
self.assertTrue('error' in res) self.assertEqual(res['error']['message'], 'Unknown block') def test_getblockheader_non_verbose(self): req = self._gen_post_rpc_req("getblockheader", params=[11, 0]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertIsNotNone(res['result']) # we should be able to instantiate a matching block with the result @@ -1710,22 +1607,19 @@ def test_gettransactionheight(self): txid = 'f999c36145a41306c846ea80290416143e8e856559818065be3f4e143c60e43a' req = self._gen_post_rpc_req("gettransactionheight", params=[txid]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertEqual(9448, res['result']) def test_gettransactionheight_invalid_hash(self): txid = 'invalid_tx_id' req = self._gen_post_rpc_req("gettransactionheight", params=[txid]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertEqual(res['error']['message'], 'Unknown transaction') def test_gettransactionheight_invalid_hash2(self): txid = 'a' * 64 # something the right length but unknown req = self._gen_post_rpc_req("gettransactionheight", params=[txid]) - mock_req = mock_post_request(json.dumps(req).encode("utf-8")) - res = json.loads(self.app.home(mock_req)) + res = json.loads(self.do_test_post("/", json=req)) self.assertTrue('error' in res) self.assertEqual(res['error']['message'], 'Unknown transaction') diff --git a/neo/api/REST/RestApi.py b/neo/api/REST/RestApi.py index caaee5a2c..671c1a365 100644 --- a/neo/api/REST/RestApi.py +++ b/neo/api/REST/RestApi.py @@ -1,37 +1,43 @@ """ -The REST API is using the Python package 'klein', which makes it possible to -create HTTP routes and handlers with Twisted in a similar style to Flask: -https://github.com/twisted/klein - +The REST API uses the Python package 'aiohttp': https://docs.aiohttp.org """ -import json -from klein import Klein -from logzero import logger +import math -from neo.Network.NodeLeader import NodeLeader -from neo.Implementations.Notifications.LevelDB.NotificationDB import NotificationDB -from neo.Core.Blockchain import Blockchain +from aiohttp import web +from logzero import logger from neo.Core.UInt160 import UInt160 from neo.Core.UInt256 import UInt256 + +from neo.Core.Blockchain import Blockchain +from neo.Implementations.Notifications.LevelDB.NotificationDB import NotificationDB +from neo.Network.nodemanager import NodeManager from neo.Settings import settings -from neo.api.utils import cors_header -import math +from neo.api.utils import json_response API_URL_PREFIX = "/v1" class RestApi: - app = Klein() notif = None def __init__(self): self.notif = NotificationDB.instance() + self.app = web.Application() + self.app.add_routes([ + web.route('*', '/', self.home), + web.get("/v1/notifications/block/{block}", self.get_by_block), + web.get("/v1/notifications/addr/{address}", self.get_by_addr), + web.get("/v1/notifications/tx/{tx_hash}", self.get_by_tx), + web.get("/v1/notifications/contract/{contract_hash}", self.get_by_contract), + web.get("/v1/token/{contract_hash}", self.get_token), + web.get("/v1/tokens", self.get_tokens), + web.get("/v1/status", self.get_status) + ]) # # REST API Routes # - @app.route('/') - def home(self, request): + 
async def home(self, request): endpoints_html = """
  • {apiPrefix}/notifications/block/<height>: notifications by block
  • {apiPrefix}/notifications/addr/<addr>: notifications by address
@@ -43,7 +49,7 @@ def home(self, request):
""".format(apiPrefix=API_URL_PREFIX) - return """ + out = """

@@ -118,37 +124,35 @@ def home(self, request): """ % (settings.net_name, endpoints_html) + return web.Response(text=out, content_type="text/html") - @app.route('%s/notifications/block/' % API_URL_PREFIX, methods=['GET']) - @cors_header - def get_by_block(self, request, block): - request.setHeader('Content-Type', 'application/json') + @json_response + async def get_by_block(self, request): try: + block = request.match_info['block'] if int(block) > Blockchain.Default().Height: return self.format_message("Higher than current block") else: - notifications = self.notif.get_by_block(block) + notifications = self.notif.get_by_block(int(block)) except Exception as e: logger.info("Could not get notifications for block %s %s" % (block, e)) return self.format_message("Could not get notifications for block %s because %s " % (block, e)) - return self.format_notifications(request, notifications) + x = self.format_notifications(request, notifications) + return x - @app.route('%s/notifications/addr/' % API_URL_PREFIX, methods=['GET']) - @cors_header - def get_by_addr(self, request, address): - request.setHeader('Content-Type', 'application/json') + @json_response + async def get_by_addr(self, request): try: + address = request.match_info['address'] notifications = self.notif.get_by_addr(address) except Exception as e: logger.info("Could not get notifications for address %s " % address) return self.format_message("Could not get notifications for address %s because %s" % (address, e)) return self.format_notifications(request, notifications) - @app.route('%s/notifications/tx/' % API_URL_PREFIX, methods=['GET']) - @cors_header - def get_by_tx(self, request, tx_hash): - request.setHeader('Content-Type', 'application/json') - + @json_response + async def get_by_tx(self, request): + tx_hash = request.match_info['tx_hash'] bc = Blockchain.Default() # type: Blockchain notifications = [] try: @@ -166,10 +170,9 @@ def get_by_tx(self, request, tx_hash): return self.format_notifications(request, notifications) - @app.route('%s/notifications/contract/' % API_URL_PREFIX, methods=['GET']) - @cors_header - def get_by_contract(self, request, contract_hash): - request.setHeader('Content-Type', 'application/json') + @json_response + async def get_by_contract(self, request): + contract_hash = request.match_info['contract_hash'] try: hash = UInt160.ParseString(contract_hash) notifications = self.notif.get_by_contract(hash) @@ -178,17 +181,14 @@ def get_by_contract(self, request, contract_hash): return self.format_message("Could not get notifications for contract hash %s because %s" % (contract_hash, e)) return self.format_notifications(request, notifications) - @app.route('%s/tokens' % API_URL_PREFIX, methods=['GET']) - @cors_header - def get_tokens(self, request): - request.setHeader('Content-Type', 'application/json') + @json_response + async def get_tokens(self, request): notifications = self.notif.get_tokens() return self.format_notifications(request, notifications) - @app.route('%s/token/' % API_URL_PREFIX, methods=['GET']) - @cors_header - def get_token(self, request, contract_hash): - request.setHeader('Content-Type', 'application/json') + @json_response + async def get_token(self, request): + contract_hash = request.match_info['contract_hash'] try: uint160 = UInt160.ParseString(contract_hash) contract_event = self.notif.get_token(uint160) @@ -201,15 +201,13 @@ def get_token(self, request, contract_hash): return self.format_notifications(request, notifications) - @app.route('%s/status' % API_URL_PREFIX, methods=['GET']) 
- @cors_header - def get_status(self, request): - request.setHeader('Content-Type', 'application/json') - return json.dumps({ - 'current_height': Blockchain.Default().Height + 1, + @json_response + async def get_status(self, request): + return { + 'current_height': Blockchain.Default().Height, 'version': settings.VERSION_NAME, - 'num_peers': len(NodeLeader.Instance().Peers) - }, indent=4, sort_keys=True) + 'num_peers': len(NodeManager().nodes) + } def format_notifications(self, request, notifications, show_none=False): @@ -217,14 +215,14 @@ def format_notifications(self, request, notifications, show_none=False): page_len = 500 page = 1 message = '' - if b'page' in request.args: + if 'page' in request.query: try: - page = int(request.args[b'page'][0]) + page = int(request.query['page']) except Exception as e: print("could not get page: %s" % e) - if b'pagesize' in request.args: + if 'pagesize' in request.query: try: - page_len = int(request.args[b'pagesize'][0]) + page_len = int(request.query['pagesize']) except Exception as e: print("could not get page length: %s" % e) @@ -243,23 +241,23 @@ def format_notifications(self, request, notifications, show_none=False): notifications = notifications[start:end] total_pages = math.ceil(notif_len / page_len) - return json.dumps({ - 'current_height': Blockchain.Default().Height + 1, + return { + 'current_height': Blockchain.Default().Height, 'message': message, 'total': notif_len, 'results': None if show_none else [n.ToJson() for n in notifications], 'page': page, 'page_len': page_len, 'total_pages': total_pages - }, indent=4, sort_keys=True) + } def format_message(self, message): - return json.dumps({ - 'current_height': Blockchain.Default().Height + 1, + return { + 'current_height': Blockchain.Default().Height, 'message': message, 'total': 0, 'results': None, 'page': 0, 'page_len': 0, 'total_pages': 0 - }, indent=4, sort_keys=True) + } diff --git a/neo/api/REST/test_rest_api.py b/neo/api/REST/test_rest_api.py index c5e20b017..15829fda9 100644 --- a/neo/api/REST/test_rest_api.py +++ b/neo/api/REST/test_rest_api.py @@ -1,56 +1,62 @@ -from neo.Utils.BlockchainFixtureTestCase import BlockchainFixtureTestCase -from neo.Settings import settings import json import os -import requests -import tarfile -import shutil -from neo.api.REST.RestApi import RestApi +from aiohttp.test_utils import AioHTTPTestCase from neo.Implementations.Notifications.LevelDB.NotificationDB import NotificationDB -from klein.test.test_resource import requestMock +from neo.Settings import settings +from neo.Utils.BlockchainFixtureTestCase import BlockchainFixtureTestCase +from neo.api.REST.RestApi import RestApi + + +class NotificationDBTestCase(BlockchainFixtureTestCase, AioHTTPTestCase): + + def __init__(self, *args, **kwargs): + super(NotificationDBTestCase, self).__init__(*args, **kwargs) + + async def get_application(self): + """ + Override the get_app method to return your application. 
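# --- Editorial aside: illustrative sketch, not part of the patch above -------
# Query-string values now arrive as strings in `request.query` (the pagination
# code above reads 'page' and 'pagesize' from it), and /v1/status returns the
# current height, version and peer count as JSON. A small client-side sketch of
# calling those endpoints; the host/port and a running api_server are assumed.
import asyncio

import aiohttp


async def fetch_examples() -> None:
    async with aiohttp.ClientSession() as session:
        async with session.get('http://127.0.0.1:8080/v1/status') as resp:
            print(await resp.json())   # {'current_height': ..., 'version': ..., 'num_peers': ...}

        params = {'page': '2', 'pagesize': '100'}   # parsed server-side from request.query
        url = 'http://127.0.0.1:8080/v1/notifications/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9'
        async with session.get(url, params=params) as resp:
            print(await resp.json())


if __name__ == '__main__':
    asyncio.run(fetch_examples())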
+ """ + self.api_server = RestApi() + return self.api_server.app + def do_test_get(self, url, data=None): + async def test_get_route(url, data=None): + resp = await self.client.get(url, data=data) + text = await resp.text() + return text -class NotificationDBTestCase(BlockchainFixtureTestCase): - app = None # type:RestApi + return self.loop.run_until_complete(test_get_route(url, data)) @classmethod def leveldb_testpath(cls): + super(NotificationDBTestCase, cls).leveldb_testpath() return os.path.join(settings.DATA_DIR_PATH, 'fixtures/test_chain') - def setUp(self): - self.app = RestApi() - def test_1_ok(self): - ndb = NotificationDB.instance() events = ndb.get_by_block(9583) self.assertEqual(len(events), 1) - def test_2_klein_app(self): - - self.assertIsNotNone(self.app.notif) + def test_2_app_server(self): + self.assertIsNotNone(self.api_server.notif) def test_3_index(self): - - mock_req = requestMock(path=b'/') - res = self.app.home(mock_req) + res = self.do_test_get("/") self.assertIn('endpoints', res) def test_4_by_block(self): - mock_req = requestMock(path=b'/block/9583') - res = self.app.get_by_block(mock_req, 9583) + res = self.do_test_get("/v1/notifications/block/9583") jsn = json.loads(res) self.assertEqual(jsn['total'], 1) results = jsn['results'] self.assertEqual(len(results), 1) def test_5_block_no_results(self): - mock_req = requestMock(path=b'/block/206') - res = self.app.get_by_block(mock_req, 206) + res = self.do_test_get("/v1/notifications/block/206") jsn = json.loads(res) self.assertEqual(jsn['total'], 0) results = jsn['results'] @@ -58,8 +64,7 @@ def test_5_block_no_results(self): self.assertEqual(len(results), 0) def test_6_block_num_too_big(self): - mock_req = requestMock(path=b'/block/2060200054055066') - res = self.app.get_by_block(mock_req, 2060200054055066) + res = self.do_test_get("/v1/notifications/block/2060200054055066") jsn = json.loads(res) self.assertEqual(jsn['total'], 0) results = jsn['results'] @@ -67,16 +72,14 @@ def test_6_block_num_too_big(self): self.assertIn('Higher than current block', jsn['message']) def test_7_by_addr(self): - mock_req = requestMock(path=b'/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9') - res = self.app.get_by_addr(mock_req, 'AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9') + res = self.do_test_get("/v1/notifications/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9") jsn = json.loads(res) self.assertEqual(jsn['total'], 1007) results = jsn['results'] self.assertEqual(len(results), 500) def test_8_bad_addr(self): - mock_req = requestMock(path=b'/addr/AcFnRrVC5emrTEkuFuRPufcuTb6KsAJ3v') - res = self.app.get_by_addr(mock_req, 'AcFnRrVC5emrTEkuFuRPufcuTb6KsAJ3v') + res = self.do_test_get("/v1/notifications/addr/AcFnRrVC5emrTEkuFuRPufcuTb6KsAJ3v") jsn = json.loads(res) self.assertEqual(jsn['total'], 0) results = jsn['results'] @@ -84,16 +87,14 @@ def test_8_bad_addr(self): self.assertIn('Could not get notifications', jsn['message']) def test_9_by_tx(self): - mock_req = requestMock(path=b'/tx/0xa2a37fd2ab7048d70d51eaa8af2815e0e542400329b05a34274771174180a7e8') - res = self.app.get_by_tx(mock_req, '0xa2a37fd2ab7048d70d51eaa8af2815e0e542400329b05a34274771174180a7e8') + res = self.do_test_get("/v1/notifications/tx/0xa2a37fd2ab7048d70d51eaa8af2815e0e542400329b05a34274771174180a7e8") jsn = json.loads(res) self.assertEqual(jsn['total'], 1) results = jsn['results'] self.assertEqual(len(results), 1) def test_9_by_bad_tx(self): - mock_req = requestMock(path=b'/tx/2e4168cb2d563714d3f35ff76b7efc6c7d428360c97b6b45a18b5b1a4faa40') - res = self.app.get_by_tx(mock_req, 
b'2e4168cb2d563714d3f35ff76b7efc6c7d428360c97b6b45a18b5b1a4faa40') + res = self.do_test_get("/v1/notifications/tx/2e4168cb2d563714d3f35ff76b7efc6c7d428360c97b6b45a18b5b1a4faa40") jsn = json.loads(res) self.assertEqual(jsn['total'], 0) results = jsn['results'] @@ -101,79 +102,75 @@ def test_9_by_bad_tx(self): self.assertIn('Could not get tx with hash', jsn['message']) def test_get_by_contract(self): - mock_req = requestMock(path=b'/contract/b9fbcff6e50fd381160b822207231233dd3c56c2') - res = self.app.get_by_contract(mock_req, 'b9fbcff6e50fd381160b822207231233dd3c56c2') + res = self.do_test_get("/v1/notifications/contract/b9fbcff6e50fd381160b822207231233dd3c56c2") jsn = json.loads(res) self.assertEqual(jsn['total'], 1006) results = jsn['results'] self.assertEqual(len(results), 500) def test_get_by_contract_empty(self): - mock_req = requestMock(path=b'/contract/910cba960880c75072d0c625dfff459f72aae047') - res = self.app.get_by_contract(mock_req, '910cba960880c75072d0c625dfff459f72aae047') + res = self.do_test_get("/v1/notifications/contract/910cba960880c75072d0c625dfff459f72aae047") jsn = json.loads(res) self.assertEqual(jsn['total'], 0) results = jsn['results'] self.assertEqual(len(results), 0) def test_get_tokens(self): - mock_req = requestMock(path=b'/tokens') - res = self.app.get_tokens(mock_req) + res = self.do_test_get("/v1/tokens") jsn = json.loads(res) self.assertEqual(jsn['total'], 5) results = jsn['results'] self.assertIsInstance(results, list) def test_pagination_for_addr_results(self): - mock_req = requestMock(path=b'/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9') - res = self.app.get_by_addr(mock_req, 'AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9') + res = self.do_test_get("/v1/notifications/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9") jsn = json.loads(res) self.assertEqual(jsn['total'], 1007) results = jsn['results'] self.assertEqual(len(results), 500) self.assertEqual(jsn['total_pages'], 3) - mock_req = requestMock(path=b'/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9?page=1') - res = self.app.get_by_addr(mock_req, 'AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9') + res = self.do_test_get("/v1/notifications/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9?page=1") jsn = json.loads(res) self.assertEqual(jsn['total'], 1007) results = jsn['results'] self.assertEqual(len(results), 500) - mock_req = requestMock(path=b'/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9?page=2') - res = self.app.get_by_addr(mock_req, 'AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9') + res = self.do_test_get("/v1/notifications/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9?page=2") jsn = json.loads(res) self.assertEqual(jsn['total'], 1007) results = jsn['results'] self.assertEqual(len(results), 500) - mock_req = requestMock(path=b'/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9?page=3') - res = self.app.get_by_addr(mock_req, 'AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9') + res = self.do_test_get("/v1/notifications/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9?page=3") jsn = json.loads(res) self.assertEqual(jsn['total'], 1007) results = jsn['results'] self.assertEqual(len(results), 7) def test_pagination_page_size_for_addr_results(self): - mock_req = requestMock(path=b'/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9?pagesize=100') - res = self.app.get_by_addr(mock_req, 'AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9') + res = self.do_test_get("/v1/notifications/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9?pagesize=100") jsn = json.loads(res) self.assertEqual(jsn['total'], 1007) results = jsn['results'] self.assertEqual(len(results), 100) self.assertEqual(jsn['total_pages'], 11) - mock_req = 
requestMock(path=b'/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9?pagesize=100&page=11') - res = self.app.get_by_addr(mock_req, 'AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9') + res = self.do_test_get("/v1/notifications/addr/AXpNr3SDfLXbPHNdqxYeHK5cYpKMHZxMZ9?pagesize=100&page=11") jsn = json.loads(res) results = jsn['results'] self.assertEqual(len(results), 7) - def test_block_heigher_than_current(self): - mock_req = requestMock(path=b'/block/8000000') - res = self.app.get_by_block(mock_req, 800000) + def test_status(self): + res = self.do_test_get("/v1/status") jsn = json.loads(res) - self.assertEqual(jsn['total'], 0) - results = jsn['results'] - self.assertIsInstance(results, type(None)) - self.assertIn('Higher than current block', jsn['message']) + self.assertEqual(12356, jsn['current_height']) + self.assertEqual(settings.VERSION_NAME, jsn['version']) + self.assertEqual(0, jsn['num_peers']) + + def test_get_token(self): + res = self.do_test_get("/v1/token/b9fbcff6e50fd381160b822207231233dd3c56c2") + jsn = json.loads(res) + result = jsn['results'][0] + self.assertEqual(9479, result['block']) + self.assertEqual("NXT2", result['token']['symbol']) diff --git a/neo/api/utils.py b/neo/api/utils.py index da7836095..ae4fdd49f 100644 --- a/neo/api/utils.py +++ b/neo/api/utils.py @@ -1,5 +1,5 @@ -import json -import gzip +from aiohttp import web +from aiohttp.web_response import ContentCoding from functools import wraps COMPRESS_FASTEST = 1 @@ -13,35 +13,11 @@ def json_response(func): """ @json_response decorator adds header and dumps response object """ @wraps(func) - def wrapper(self, request, *args, **kwargs): - res = func(self, request, *args, **kwargs) - response_data = json.dumps(res) if isinstance(res, (dict, list)) else res - request.setHeader('Content-Type', 'application/json') - - if len(response_data) > COMPRESS_THRESHOLD: - accepted_encodings = request.requestHeaders.getRawHeaders('Accept-Encoding') - if accepted_encodings: - use_gzip = any("gzip" in encoding for encoding in accepted_encodings) - - if use_gzip: - response_data = gzip.compress(bytes(response_data, 'utf-8'), compresslevel=COMPRESS_FASTEST) - request.setHeader('Content-Encoding', 'gzip') - request.setHeader('Content-Length', len(response_data)) - - return response_data - - return wrapper - - -# @cors_header decorator to add the CORS headers -def cors_header(func): - """ @cors_header decorator adds CORS headers """ - - @wraps(func) - def wrapper(self, request, *args, **kwargs): - res = func(self, request, *args, **kwargs) - request.setHeader('Access-Control-Allow-Origin', '*') - request.setHeader('Access-Control-Allow-Headers', 'Content-Type, Access-Control-Allow-Headers, Authorization, X-Requested-With') - return res + async def wrapper(self, request, *args, **kwargs): + res = await func(self, request, *args, **kwargs) + response = web.json_response(data=res) + if response.content_length > COMPRESS_THRESHOLD: + response.enable_compression(force=ContentCoding.gzip) + return response return wrapper diff --git a/neo/bin/api_server.py b/neo/bin/api_server.py index e1d5efbd8..57af6a47d 100755 --- a/neo/bin/api_server.py +++ b/neo/bin/api_server.py @@ -14,7 +14,6 @@ See also: -* If you encounter any issues, please report them here: https://github.com/CityOfZion/neo-python/issues/273 * Server setup * Guide for Ubuntu server setup: https://gist.github.com/metachris/2be27cdff9503ebe7db1c27bfc60e435 * Systemd service config: https://gist.github.com/metachris/03d1cc47df7cddfbc4009d5249bdfc6c @@ -25,42 +24,29 @@ This api-server can log to 
stdout/stderr, logfile and syslog. Check `api-server.py -h` for more details. - -Twisted uses a quite custom logging setup. Here we simply setup the Twisted logger -to reuse our logzero logging setup. See also: - -* http://twisted.readthedocs.io/en/twisted-17.9.0/core/howto/logger.html -* https://twistedmatrix.com/documents/17.9.0/api/twisted.logger.STDLibLogObserver.html """ +import argparse +import asyncio import os import sys -import argparse -import threading -from time import sleep from logging.handlers import SysLogHandler import logzero from logzero import logger -from prompt_toolkit import prompt - -# Twisted logging -from twisted.logger import STDLibLogObserver, globalLogPublisher - -# Twisted and Klein methods and modules -from twisted.internet import reactor, task, endpoints, threads -from twisted.web.server import Site +from neo.Network.common import blocking_prompt as prompt +from aiohttp import web +from signal import SIGINT # neo methods and modules from neo.Core.Blockchain import Blockchain from neo.Implementations.Blockchains.LevelDB.LevelDBBlockchain import LevelDBBlockchain from neo.Implementations.Notifications.LevelDB.NotificationDB import NotificationDB -from neo.Wallets.utils import to_aes_key from neo.Implementations.Wallets.peewee.UserWallet import UserWallet - -from neo.Network.NodeLeader import NodeLeader +from neo.Network.p2pservice import NetworkService from neo.Settings import settings from neo.Utils.plugin import load_class_from_path -import neo.Settings +from neo.Wallets.utils import to_aes_key +from contextlib import suppress # Logfile default settings (only used if --logfile arg is used) LOGFILE_MAX_BYTES = 5e7 # 50 MB @@ -79,7 +65,7 @@ def write_pid_file(): f.write(str(os.getpid())) -def custom_background_code(): +async def custom_background_code(): """ Custom code run in a background thread. This function is run in a daemonized thread, which means it can be instantly killed at any @@ -87,36 +73,11 @@ def custom_background_code(): thread and handle exiting this thread in another way (eg. with signals and events). """ while True: - logger.info("[%s] Block %s / %s", settings.net_name, str(Blockchain.Default().Height + 1), str(Blockchain.Default().HeaderHeight + 1)) - sleep(15) - - -def on_persistblocks_error(err): - logger.debug("On Persist blocks loop error! %s " % err) + logger.info("[%s] Block %s / %s", settings.net_name, str(Blockchain.Default().Height), str(Blockchain.Default().HeaderHeight)) + await asyncio.sleep(15) -def stop_block_persisting(): - global continue_persisting - continue_persisting = False - - -def persist_done(value): - """persist callback. 
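# --- Editorial aside: illustrative sketch, not part of the patch above -------
# custom_background_code() above is the general asyncio replacement for a
# daemonized thread with time.sleep(): schedule a coroutine on the loop and
# cancel it on shutdown instead of killing a thread. The generic shape of that
# pattern (the printed message and the short interval are placeholders):
import asyncio


async def periodic(interval: float) -> None:
    # runs until cancelled; the cancellation surfaces inside asyncio.sleep()
    while True:
        print("heartbeat")              # the real task logs block/header height
        await asyncio.sleep(interval)


async def main() -> None:
    task = asyncio.create_task(periodic(0.5))
    await asyncio.sleep(2)              # ... the rest of the program runs here ...
    task.cancel()                       # graceful stop instead of a killed daemon thread
    try:
        await task
    except asyncio.CancelledError:
        pass


if __name__ == '__main__':
    asyncio.run(main())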
Value is unused""" - if continue_persisting: - sleep(0.1) - start_block_persisting() - else: - block_deferred.cancel() - - -def start_block_persisting(): - global block_deferred - block_deferred = threads.deferToThread(Blockchain.Default().PersistBlocks) - block_deferred.addCallback(persist_done) - block_deferred.addErrback(on_persistblocks_error) - - -def main(): +async def setup_and_start(loop): parser = argparse.ArgumentParser() # Network options @@ -145,7 +106,10 @@ def main(): parser.add_argument("--datadir", action="store", help="Absolute path to use for database directories") # peers - parser.add_argument("--maxpeers", action="store", default=5, + parser.add_argument("--minpeers", action="store", type=int, + help="Min peers to use for P2P Joining") + + parser.add_argument("--maxpeers", action="store", type=int, help="Max peers to use for P2P Joining") # If a wallet should be opened @@ -163,17 +127,17 @@ def main(): if not args.port_rpc and not args.port_rest: print("Error: specify at least one of --port-rpc / --port-rest") parser.print_help() - return + raise SystemExit if args.port_rpc == args.port_rest: print("Error: --port-rpc and --port-rest cannot be the same") parser.print_help() - return + raise SystemExit if args.logfile and (args.syslog or args.syslog_local): print("Error: Cannot only use logfile or syslog at once") parser.print_help() - return + raise SystemExit # Setting the datadir must come before setting the network, else the wrong path is checked at net setup. if args.datadir: @@ -191,13 +155,45 @@ def main(): elif args.coznet: settings.setup_coznet() - if args.maxpeers: + def set_min_peers(num_peers) -> bool: + try: + settings.set_min_peers(num_peers) + print("Minpeers set to ", num_peers) + return True + except ValueError: + print("Please supply a positive integer for minpeers") + return False + + def set_max_peers(num_peers) -> bool: try: - settings.set_max_peers(args.maxpeers) - print("Maxpeers set to ", args.maxpeers) + settings.set_max_peers(num_peers) + print("Maxpeers set to ", num_peers) + return True except ValueError: print("Please supply a positive integer for maxpeers") - return + return False + + minpeers = args.minpeers + maxpeers = args.maxpeers + + if minpeers and maxpeers: + if minpeers > maxpeers: + print("minpeers setting cannot be bigger than maxpeers setting") + return + if not set_min_peers(minpeers) or not set_max_peers(maxpeers): + return + elif minpeers: + if not set_min_peers(minpeers): + return + if minpeers > settings.CONNECTED_PEER_MAX: + if not set_max_peers(minpeers): + return + elif maxpeers: + if not set_max_peers(maxpeers): + return + if maxpeers < settings.CONNECTED_PEER_MIN: + if not set_min_peers(maxpeers): + return if args.syslog or args.syslog_local is not None: # Setup the syslog facility @@ -239,6 +235,7 @@ def main(): password_key = to_aes_key(passwd) try: wallet = UserWallet.Open(args.wallet, password_key) + asyncio.create_task(wallet.sync_wallet(start_block=wallet._current_height)) except Exception as e: print(f"Could not open wallet {e}") @@ -252,35 +249,16 @@ def main(): # Write a PID file to easily quit the service write_pid_file() - # Setup Twisted and Klein logging to use the logzero setup - observer = STDLibLogObserver(name=logzero.LOGZERO_DEFAULT_LOGGER) - globalLogPublisher.addObserver(observer) - - def loopingCallErrorHandler(error): - logger.info("Error in loop: %s " % error) - # Instantiate the blockchain and subscribe to notifications blockchain = LevelDBBlockchain(settings.chain_leveldb_path) 
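# --- Editorial aside: illustrative sketch, not part of the patch above -------
# The --minpeers/--maxpeers handling above (repeated later in prompt.py) boils
# down to: both values must be positive, minpeers may not exceed maxpeers, and
# supplying only one of them drags the other limit along so the pair stays
# consistent. A standalone sketch of those rules; the function name and the
# example limits are hypothetical.
from typing import Optional, Tuple


def resolve_peer_limits(minpeers: Optional[int], maxpeers: Optional[int],
                        current_min: int, current_max: int) -> Tuple[int, int]:
    """Return the (min, max) pair the settings should end up with."""
    if minpeers is not None and minpeers <= 0:
        raise ValueError("minpeers must be a positive integer")
    if maxpeers is not None and maxpeers <= 0:
        raise ValueError("maxpeers must be a positive integer")

    if minpeers is not None and maxpeers is not None:
        if minpeers > maxpeers:
            raise ValueError("minpeers cannot be bigger than maxpeers")
        return minpeers, maxpeers
    if minpeers is not None:
        return minpeers, max(minpeers, current_max)   # raise the max along if needed
    if maxpeers is not None:
        return min(maxpeers, current_min), maxpeers   # lower the min along if needed
    return current_min, current_max


# e.g. with current limits (5, 10), passing only --maxpeers 3 lowers both to 3
assert resolve_peer_limits(None, 3, 5, 10) == (3, 3)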
Blockchain.RegisterBlockchain(blockchain) - start_block_persisting() + p2p = NetworkService() + p2p_task = loop.create_task(p2p.start()) + loop.create_task(custom_background_code()) - # If a wallet is open, make sure it processes blocks - if wallet: - walletdb_loop = task.LoopingCall(wallet.ProcessBlocks) - wallet_loop_deferred = walletdb_loop.start(1) - wallet_loop_deferred.addErrback(loopingCallErrorHandler) - - # Setup twisted reactor, NodeLeader and start the NotificationDB - reactor.suggestThreadPoolSize(15) - NodeLeader.Instance().Start() NotificationDB.instance().start() - # Start a thread with custom code - d = threading.Thread(target=custom_background_code) - d.setDaemon(True) # daemonizing the thread will kill it when the main thread is quit - d.start() - if args.port_rpc: logger.info("Starting json-rpc api server on http://%s:%s" % (args.host, args.port_rpc)) try: @@ -288,10 +266,12 @@ def loopingCallErrorHandler(error): except ValueError as err: logger.error(err) sys.exit() - api_server_rpc = rpc_class(args.port_rpc, wallet=wallet) + api_server_rpc = rpc_class(wallet=wallet) - endpoint_rpc = "tcp:port={0}:interface={1}".format(args.port_rpc, args.host) - endpoints.serverFromString(reactor, endpoint_rpc).listen(Site(api_server_rpc.app.resource())) + runner = web.AppRunner(api_server_rpc.app) + await runner.setup() + site = web.TCPSite(runner, args.host, args.port_rpc) + await site.start() if args.port_rest: logger.info("Starting REST api server on http://%s:%s" % (args.host, args.port_rest)) @@ -301,17 +281,50 @@ def loopingCallErrorHandler(error): logger.error(err) sys.exit() api_server_rest = rest_api() - endpoint_rest = "tcp:port={0}:interface={1}".format(args.port_rest, args.host) - endpoints.serverFromString(reactor, endpoint_rest).listen(Site(api_server_rest.app.resource())) + runner = web.AppRunner(api_server_rest.app) + await runner.setup() + site = web.TCPSite(runner, args.host, args.port_rpc) + await site.start() + + return wallet + + +async def shutdown(): + # cleanup any remaining tasks + all_tasks = asyncio.all_tasks() + for task in all_tasks: + task.cancel() + with suppress((asyncio.CancelledError, Exception)): + await task - reactor.addSystemEventTrigger('before', 'shutdown', stop_block_persisting) - reactor.run() - # After the reactor is stopped, gracefully shutdown the database. +def system_exit(): + raise SystemExit + + +def main(): + loop = asyncio.get_event_loop() + + # because a KeyboardInterrupt is so violent it can shutdown the DB in an unpredictable state. 
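# --- Editorial aside: illustrative sketch, not part of the patch above -------
# web.AppRunner + web.TCPSite replace Twisted's endpoints.serverFromString and
# Site for serving the JSON-RPC and REST apps. A generic helper in that shape;
# note that in the REST branch above the TCPSite appears to bind args.port_rpc,
# while this sketch simply binds whatever port it is given. Host and port are
# placeholders.
import asyncio

from aiohttp import web


async def serve(app: web.Application, host: str, port: int) -> web.AppRunner:
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, host, port)
    await site.start()
    return runner                     # keep it so runner.cleanup() can be awaited on shutdown


async def main() -> None:
    runner = await serve(web.Application(), '127.0.0.1', 8080)
    try:
        await asyncio.sleep(3600)     # the real server runs until it is told to stop
    finally:
        await runner.cleanup()


if __name__ == '__main__':
    asyncio.run(main())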
+ loop.add_signal_handler(SIGINT, system_exit) + main_task = loop.create_task(setup_and_start(loop)) + + try: + loop.run_forever() + except SystemExit: + p2p = NetworkService() + loop.run_until_complete(p2p.shutdown()) + loop.run_until_complete(shutdown()) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.stop() + finally: + loop.close() + logger.info("Closing databases...") NotificationDB.close() Blockchain.Default().Dispose() - NodeLeader.Instance().Shutdown() + + wallet = main_task.result() if wallet: wallet.Close() diff --git a/neo/bin/import_blocks.py b/neo/bin/import_blocks.py index 6def20ea1..d89772c45 100644 --- a/neo/bin/import_blocks.py +++ b/neo/bin/import_blocks.py @@ -15,9 +15,10 @@ from tqdm import trange from prompt_toolkit import prompt from neo.Implementations.Notifications.LevelDB.NotificationDB import NotificationDB +import asyncio -def main(): +async def main(): parser = argparse.ArgumentParser() parser.add_argument("-m", "--mainnet", action="store_true", default=False, help="use MainNet instead of the default TestNet") @@ -142,7 +143,7 @@ def main(): # add if block.Index > start_block: - chain.AddBlockDirectly(block, do_persist_complete=store_notifications) + await chain.TryPersist(block) # reset blockheader block._header = None @@ -183,4 +184,4 @@ def main(): if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/neo/bin/prompt.py b/neo/bin/prompt.py index ff3dc3908..f08fee38a 100755 --- a/neo/bin/prompt.py +++ b/neo/bin/prompt.py @@ -4,16 +4,18 @@ import datetime import os import traceback +import asyncio +import termios +import sys from prompt_toolkit.completion import WordCompleter from prompt_toolkit.history import FileHistory from prompt_toolkit.shortcuts import print_formatted_text, PromptSession from prompt_toolkit.formatted_text import FormattedText -from twisted.internet import reactor, task +from prompt_toolkit.application import get_app as prompt_toolkit_get_app from neo import __version__ from neo.Core.Blockchain import Blockchain from neo.Implementations.Blockchains.LevelDB.LevelDBBlockchain import LevelDBBlockchain from neo.Implementations.Notifications.LevelDB.NotificationDB import NotificationDB -from neo.Network.NodeLeader import NodeLeader from neo.Prompt.Commands.Wallet import CommandWallet from neo.Prompt.Commands.Show import CommandShow from neo.Prompt.Commands.Search import CommandSearch @@ -25,9 +27,14 @@ from neo.UserPreferences import preferences from neo.logging import log_manager from neo.Prompt.PromptPrinter import prompt_print, token_style +from neo.Network.nodemanager import NodeManager logger = log_manager.getLogger() +from prompt_toolkit.eventloop import use_asyncio_event_loop +from neo.Network.p2pservice import NetworkService +from contextlib import suppress + class PromptFileHistory(FileHistory): def append(self, string): @@ -84,6 +91,8 @@ class PromptInterface: start_height = None start_dt = None + prompt_session = None + def __init__(self, history_filename=None): PromptData.Prompt = self if history_filename: @@ -134,9 +143,7 @@ def quit(self): print('Shutting down. 
This may take a bit...') self.go_on = False PromptData.close_wallet() - Blockchain.Default().Dispose() - NodeLeader.Instance().Shutdown() - reactor.stop() + raise SystemExit def help(self): prompt_print(f"\nCommands:") @@ -145,26 +152,13 @@ def help(self): prompt_print(f" {command_group:<15} - {command.command_desc().short_help}") prompt_print(f"\nRun 'COMMAND help' for more information on a command.") - def start_wallet_loop(self): - if self.wallet_loop_deferred: - self.stop_wallet_loop() - self.walletdb_loop = task.LoopingCall(PromptData.Wallet.ProcessBlocks) - self.wallet_loop_deferred = self.walletdb_loop.start(1) - self.wallet_loop_deferred.addErrback(self.on_looperror) - - def stop_wallet_loop(self): - self.wallet_loop_deferred.cancel() - self.wallet_loop_deferred = None - if self.walletdb_loop and self.walletdb_loop.running: - self.walletdb_loop.stop() - def on_looperror(self, err): logger.debug("On DB loop error! %s " % err) - def run(self): - dbloop = task.LoopingCall(Blockchain.Default().PersistBlocks) - dbloop_deferred = dbloop.start(.1) - dbloop_deferred.addErrback(self.on_looperror) + async def run(self): + nodemgr = NodeManager() + while not nodemgr.running: + await asyncio.sleep(0.1) tokens = [("class:neo", 'NEO'), ("class:default", ' cli. Type '), ("class:command", '\'help\' '), ("class:default", 'to get started')] @@ -173,24 +167,39 @@ def run(self): print('\n') - while self.go_on: - - session = PromptSession("neo> ", - completer=self.get_completer(), - history=self.history, - bottom_toolbar=self.get_bottom_toolbar, - style=token_style, - refresh_interval=3, - ) + session = PromptSession("neo> ", + completer=self.get_completer(), + history=self.history, + bottom_toolbar=self.get_bottom_toolbar, + style=token_style, + refresh_interval=3, + ) + self.prompt_session = session + result = "" + while self.go_on: + # with patch_stdout(): try: - result = session.prompt() + result = await session.prompt(async_=True) except EOFError: # Control-D pressed: quit return self.quit() except KeyboardInterrupt: - # Control-C pressed: do nothing - continue + # Control-C pressed: pause for user input + + # temporarily mute stdout during user input + # components like `network` set at DEBUG level will spam through the console + # making it impractical to input user data + log_manager.mute_stdio() + + print('Logging output muted during user input...') + try: + result = await session.prompt(async_=True) + except Exception as e: + logger.error("Exception handling input: %s " % e) + + # and re-enable stdio + log_manager.unmute_stdio() except Exception as e: logger.error("Exception handling input: %s " % e) @@ -253,7 +262,10 @@ def main(): help="Absolute path to use for database directories") # peers - parser.add_argument("--maxpeers", action="store", default=5, + parser.add_argument("--minpeers", action="store", type=int, + help="Min peers to use for P2P Joining") + + parser.add_argument("--maxpeers", action="store", type=int, default=5, help="Max peers to use for P2P Joining") # Show the neo-python version @@ -294,8 +306,49 @@ def main(): if args.verbose: settings.set_log_smart_contract_events(True) - if args.maxpeers: - settings.set_max_peers(args.maxpeers) + def set_min_peers(num_peers) -> bool: + try: + settings.set_min_peers(num_peers) + print("Minpeers set to ", num_peers) + return True + except ValueError: + print("Please supply a positive integer for minpeers") + return False + + def set_max_peers(num_peers) -> bool: + try: + settings.set_max_peers(num_peers) + print("Maxpeers set to ", 
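# --- Editorial aside: illustrative sketch, not part of the patch above -------
# The prompt loop above awaits user input instead of running the CLI in a
# Twisted thread. With prompt_toolkit 2.0 (as pinned in requirements.txt) that
# takes two steps: route prompt_toolkit through asyncio, then request the async
# variant of prompt(). The echo loop is an illustration, not the real CLI.
import asyncio

from prompt_toolkit.eventloop import use_asyncio_event_loop
from prompt_toolkit.shortcuts import PromptSession


async def cli() -> None:
    session = PromptSession("neo> ")
    while True:
        try:
            text = await session.prompt(async_=True)   # does not block the event loop
        except (EOFError, KeyboardInterrupt):          # Ctrl-D / Ctrl-C
            break
        print("you typed:", text)


def main() -> None:
    use_asyncio_event_loop()          # must be called before prompting
    loop = asyncio.get_event_loop()
    loop.run_until_complete(cli())


if __name__ == '__main__':
    main()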
num_peers) + return True + except ValueError: + print("Please supply a positive integer for maxpeers") + return False + + minpeers = args.minpeers + maxpeers = args.maxpeers + + if minpeers and maxpeers: + if minpeers > maxpeers: + print("minpeers setting cannot be bigger than maxpeers setting") + return + if not set_min_peers(minpeers) or not set_max_peers(maxpeers): + return + elif minpeers: + if not set_min_peers(minpeers): + return + if minpeers > settings.CONNECTED_PEER_MAX: + if not set_max_peers(minpeers): + return + elif maxpeers: + if not set_max_peers(maxpeers): + return + if maxpeers < settings.CONNECTED_PEER_MIN: + if not set_min_peers(maxpeers): + return + + loop = asyncio.get_event_loop() + # put prompt_toolkit on top of asyncio to avoid blocking + use_asyncio_event_loop() # Instantiate the blockchain and subscribe to notifications blockchain = LevelDBBlockchain(settings.chain_leveldb_path) @@ -309,19 +362,45 @@ def main(): fn_prompt_history = os.path.join(settings.DATA_DIR_PATH, '.prompt.py.history') cli = PromptInterface(fn_prompt_history) - # Run things - - reactor.callInThread(cli.run) - - NodeLeader.Instance().Start() + cli_task = loop.create_task(cli.run()) + p2p = NetworkService() + loop.create_task(p2p.start()) + + async def shutdown(): + all_tasks = asyncio.all_tasks() + for task in all_tasks: + task.cancel() + with suppress(asyncio.CancelledError): + await task + + # prompt_toolkit hack for not cleaning up see: https://github.com/prompt-toolkit/python-prompt-toolkit/issues/787 + old_attrs = termios.tcgetattr(sys.stdin) + + try: + loop.run_forever() + except SystemExit: + pass + finally: + with suppress(asyncio.InvalidStateError): + app = prompt_toolkit_get_app() + if app.is_running: + app.exit() + with suppress((SystemExit, Exception)): + cli_task.exception() + loop.run_until_complete(p2p.shutdown()) + loop.run_until_complete(shutdown()) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.stop() + loop.close() - # reactor.run() is blocking, until `quit()` is called which stops the reactor. - reactor.run() + # Run things # After the reactor is stopped, gracefully shutdown the database. NotificationDB.close() Blockchain.Default().Dispose() - NodeLeader.Instance().Shutdown() + + # clean up prompt_toolkit mess, see above + termios.tcsetattr(sys.stdin, termios.TCSANOW, old_attrs) if __name__ == "__main__": diff --git a/neo/bin/test_prompt.py b/neo/bin/test_prompt.py deleted file mode 100644 index 3fe0f6aa9..000000000 --- a/neo/bin/test_prompt.py +++ /dev/null @@ -1,27 +0,0 @@ -from unittest import TestCase, skip -import pexpect - - -class PromptTest(TestCase): - - @skip("Unreliable due to system resource dependency. Replace later with better alternative") - def test_prompt_run(self): - child = pexpect.spawn('python neo/bin/prompt.py') - child.expect([pexpect.EOF, pexpect.TIMEOUT], timeout=10) # if test is failing consider increasing timeout time - before = child.before - text = before.decode('utf-8', 'ignore') - checktext = "neo>" - self.assertIn(checktext, text) - child.terminate() - - @skip("Unreliable due to system resource dependency. 
Replace later with better alternative")
-    def test_prompt_open_wallet(self):
-        child = pexpect.spawn('python neo/bin/prompt.py')
-        child.send('open wallet fixtures/testwallet.db3\n')
-        child.send('testpassword\n')
-        child.expect([pexpect.EOF, pexpect.TIMEOUT], timeout=15)  # if test is failing consider increasing timeout time
-        before = child.before
-        text = before.decode('utf-8', 'ignore')
-        checktext = "Opened"
-        self.assertIn(checktext, text)
-        child.terminate()
diff --git a/neo/logging.py b/neo/logging.py
index 25a8764b9..df360315c 100644
--- a/neo/logging.py
+++ b/neo/logging.py
@@ -30,7 +30,7 @@
     logger.info("I log for generic components like the prompt or Util classes")
     network_logger = log_manager.getLogger('network')
-    logger.info("I log for network classes like NodeLeader and NeoNode")
+    logger.info("I log for network classes like NeoNode and SyncManager")

     # since network classes can be very active and verbose, we might want to raise the level to just show ERROR or above
     logconfig = ('network', logging.ERROR)  # a tuple of (`component name`, `log level`)
diff --git a/requirements.txt b/requirements.txt
index a75611def..d90709e59 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,101 +1,42 @@
 aenum==2.1.2
-alabaster==0.7.12
-appnope==0.1.0
-asn1crypto==0.24.0
+aiohttp==3.5.4
+aiohttp-cors==0.7.0
 astor==0.7.1
 asynctest==0.13.0
+async-timeout==3.0.1
 attrs==19.1.0
-Automat==0.7.0
-autopep8==1.4.4
-Babel==2.6.0
-backcall==0.1.0
+autopep8==1.4.3
 base58==1.0.3
 bitcoin==1.1.42
-blessings==1.7
-bpython==0.18
-bumpversion==0.5.3
 certifi==2019.3.9
-cffi==1.12.3
 chardet==3.0.4
-colorlog==4.0.2
-constantly==15.1.0
 coverage==4.5.3
 coveralls==1.7.0
 coz-bytecode==0.5.1
-cryptography==2.6.1
-curtsies==0.3.0
-cycler==0.10.0
-decorator==4.4.0
 docopt==0.6.2
-docutils==0.14
 ecdsa==0.13
-Events==0.3
-furl==2.0.0
-gevent==1.4.0
-greenlet==0.4.15
-hyperlink==19.0.0
 idna==2.8
-imagesize==1.1.0
-incremental==17.5.0
-ipython==7.5.0
-ipython-genutils==0.2.0
-jedi==0.13.3
-Jinja2==2.10.1
-klein==17.10.0
 logzero==1.5.0
-MarkupSafe==1.1.1
-memory-profiler==0.55.0
 mmh3==2.5.1
-mock==3.0.5
+mock==2.0.0
 mpmath==1.1.0
-neo-boa==0.5.6
--e git+https://github.com/ixje/neo-python.git@4a5d8d2e005a3f60b3542931e65517987d7d1446#egg=neo_python
+multidict==4.5.2
+git+https://github.com/ixje/neo-boa@01ea0207250c8ee96bca673e6e587796888e58d1#egg=neo-boa
 neo-python-rpc==0.2.1
-neocore==0.5.6
-neopython-extended-rpc-server==0.1.0
-orderedmultidict==1.0
-packaging==19.0
-parso==0.4.0
-pbr==5.2.0
-peewee==3.9.5
-pexpect==4.7.0
-pickleshare==0.7.5
-pip-review==1.0
-pluggy==0.11.0
-plyvel==1.1.0
+pbr==5.1.3
+peewee==3.9.2
+plyvel==1.0.5
 prompt-toolkit==2.0.9
-psutil==5.6.2
-ptyprocess==0.6.0
-py==1.8.0
+psutil==5.6.1
 pycodestyle==2.5.0
-pycparser==2.19
-pycryptodome==3.7.2
-Pygments==2.4.0
-PyHamcrest==1.9.0
+pycryptodome==3.7.3
 pymitter==0.2.3
-Pympler==0.7
-pyparsing==2.4.0
-python-dateutil==2.8.0
-pytz==2019.1
+pyparsing==2.3.1
+pytz==2018.9
 requests==2.21.0
 scrypt==0.8.6
 six==1.12.0
-snowballstemmer==1.2.1
-Sphinx==2.0.1
-sphinx-rtd-theme==0.4.3
-sphinxcontrib-applehelp==1.0.1
-sphinxcontrib-devhelp==1.0.1
-sphinxcontrib-htmlhelp==1.0.2
-sphinxcontrib-jsmath==1.0.1
-sphinxcontrib-qthelp==1.0.2
-sphinxcontrib-serializinghtml==1.1.3
-sphinxcontrib-websupport==1.1.0
 tqdm==4.29.1
-traitlets==4.3.2
-Twisted==18.9.0
-typing==3.6.6
-urllib3==1.24.2
-virtualenv==16.6.0
+urllib3==1.24.1
 wcwidth==0.1.7
-Werkzeug==0.15.4
-zope.interface==4.6.0
+yarl==1.3.0
\ No newline at end of file
diff --git a/setup.py b/setup.py
index a0c868e74..ec2bc2361 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@ setup(
     name='neo-python',
-    python_requires='>=3.6',
+    python_requires='>=3.7',
     version='0.8.5-dev',
     description="Python Node and SDK for the NEO blockchain",
     long_description=readme,
@@ -52,7 +52,6 @@
         'License :: OSI Approved :: MIT License',
         'Natural Language :: English',
         'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
     ]
)