Skip to content
45 changes: 24 additions & 21 deletions protocoin/clients.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from cStringIO import StringIO
from io import BytesIO
from .serializers import *
from .exceptions import NodeDisconnectException
import os
Expand All @@ -15,7 +15,7 @@ class BitcoinBasicClient(object):

def __init__(self, socket):
self.socket = socket
self.buffer = StringIO()
self.buffer = BytesIO()

def close_stream(self):
"""This method will close the socket stream."""
Expand Down Expand Up @@ -52,7 +52,7 @@ def receive_message(self):
return

# Go to the beginning of the buffer
self.buffer.reset()
self.buffer.seek(0)

message_model = None
message_header_serial = MessageHeaderSerializer()
Expand All @@ -66,7 +66,9 @@ def receive_message(self):
return

payload = self.buffer.read(message_header.length)
self.buffer = StringIO()
buffer = BytesIO()
buffer.write(self.buffer.read())
self.buffer = buffer
self.handle_message_header(message_header, payload)

payload_checksum = \
Expand All @@ -78,7 +80,7 @@ def receive_message(self):

if message_header.command in MESSAGE_MAPPING:
deserializer = MESSAGE_MAPPING[message_header.command]()
message_model = deserializer.deserialize(StringIO(payload))
message_model = deserializer.deserialize(BytesIO(payload))

return (message_header, message_model)

Expand All @@ -89,7 +91,7 @@ def send_message(self, message):

:param message: The message object to send
"""
bin_data = StringIO()
bin_data = BytesIO()
message_header = MessageHeader(self.coin)
message_header_serial = MessageHeaderSerializer()

Expand Down Expand Up @@ -118,21 +120,22 @@ def loop(self):
raise NodeDisconnectException("Node disconnected.")

self.buffer.write(data)
data = self.receive_message()

# Check if the message is still incomplete to parse
if data is None:
continue

# Check for the header and message
message_header, message = data
if not message:
continue

handle_func_name = "handle_" + message_header.command
handle_func = getattr(self, handle_func_name, None)
if handle_func:
handle_func(message_header, message)
while True:
data = self.receive_message()

# Check if the message is still incomplete to parse
if data is None:
break

# Check for the header and message
message_header, message = data
if not message:
continue

handle_func_name = "handle_" + message_header.command
handle_func = getattr(self, handle_func_name, None)
if handle_func:
handle_func(message_header, message)

class BitcoinClient(BitcoinBasicClient):
"""This class implements all the protocol rules needed
Expand Down
31 changes: 15 additions & 16 deletions protocoin/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .exceptions import NodeDisconnectException

from cStringIO import StringIO
from io import BytesIO
import struct
import time
import random
Expand Down Expand Up @@ -157,12 +157,11 @@ def parse(self, value):

def deserialize(self, stream):
data = stream.read(self.length)
return data.split("\x00", 1)[0]
return data[:(data+b'\x00').index(b'\x00')].decode("utf-8")

def serialize(self):
bin_data = StringIO()
bin_data.write(self.value[:self.length])
bin_data.write("\x00" * (12 - len(self.value)))
bin_data = BytesIO()
bin_data.write(struct.pack("12s", self.value.encode("utf-8")))
return bin_data.getvalue()

class NestedField(Field):
Expand Down Expand Up @@ -211,7 +210,7 @@ def parse(self, value):
self.value = value

def serialize(self):
bin_data = StringIO()
bin_data = BytesIO()
self.var_int.parse(len(self))
bin_data.write(self.var_int.serialize())
serializer = self.serializer_class()
Expand All @@ -223,7 +222,7 @@ def deserialize(self, stream):
count = self.var_int.deserialize(stream)
items = []
serializer = self.serializer_class()
for i in xrange(count):
for i in range(count):
data = serializer.deserialize(stream)
items.append(data)
return items
Expand All @@ -236,7 +235,7 @@ def __len__(self):

class IPv4AddressField(Field):
"""An IPv4 address field without timestamp and reserved IPv6 space."""
reserved = "\x00"*10 + "\xff"*2
reserved = b"\x00"*10 + b"\xff"*2

def parse(self, value):
self.value = value
Expand All @@ -247,7 +246,7 @@ def deserialize(self, stream):
return socket.inet_ntoa(addr)

def serialize(self):
bin_data = StringIO()
bin_data = BytesIO()
bin_data.write(self.reserved)
bin_data.write(socket.inet_aton(self.value))
return bin_data.getvalue()
Expand All @@ -273,12 +272,12 @@ def deserialize(self, stream):

def serialize(self):
if self.value < 0xFD:
return chr(self.value)
return struct.pack("B", self.value)
if self.value <= 0xFFFF:
return chr(0xFD) + struct.pack("<H", self.value)
return b'\xFD' + struct.pack("<H", self.value)
if self.value <= 0xFFFFFFFF:
return chr(0xFE) + struct.pack("<I", self.value)
return chr(0xFF) + struct.pack("<Q", self.value)
return b'\xFE' + struct.pack("<I", self.value)
return b'\xFF' + struct.pack("<Q", self.value)

class VariableStringField(Field):
"""A variable length string field."""
Expand All @@ -297,9 +296,9 @@ def deserialize(self, stream):

def serialize(self):
self.var_int.parse(len(self))
bin_data = StringIO()
bin_data = BytesIO()
bin_data.write(self.var_int.serialize())
bin_data.write(self.value)
bin_data.write(bytes(self.value, "utf-8"))
return bin_data.getvalue()

def __len__(self):
Expand All @@ -323,7 +322,7 @@ def deserialize(self, stream):

def serialize(self):
hash_ = self.value
bin_data = StringIO()
bin_data = BytesIO()
for i in range(8):
pack_data = struct.pack(self.datatype, hash_ & 0xFFFFFFFF)
bin_data.write(pack_data)
Expand Down
27 changes: 13 additions & 14 deletions protocoin/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import random
import hashlib
import struct
from cStringIO import StringIO
from io import BytesIO
from collections import OrderedDict

from . import fields
Expand All @@ -21,7 +21,7 @@ def get_fields(meta, bases, attrs, field_class):
"""This method will construct an ordered dict with all
the fields present on the serializer classes."""
fields = [(field_name, attrs.pop(field_name))
for field_name, field_value in list(attrs.iteritems())
for field_name, field_value in list(attrs.items())
if isinstance(field_value, field_class)]

for base_cls in bases[::-1]:
Expand All @@ -31,9 +31,8 @@ def get_fields(meta, bases, attrs, field_class):
fields.sort(key=lambda it: it[1].count)
return OrderedDict(fields)

class SerializerABC(object):
class SerializerABC(object, metaclass = SerializerMeta):
"""The serializer abstract base class."""
__metaclass__ = SerializerMeta

class Serializer(SerializerABC):
"""The main serializer class, inherit from this class to
Expand All @@ -50,8 +49,8 @@ def serialize(self, obj, fields=None):

:param obj: The object to serializer.
"""
bin_data = StringIO()
for field_name, field_obj in self._fields.iteritems():
bin_data = BytesIO()
for field_name, field_obj in self._fields.items():
if fields:
if field_name not in fields:
continue
Expand All @@ -65,10 +64,10 @@ def deserialize(self, stream):
"""This method will read the stream and then will deserialize the
binary data information present on it.

:param stream: A file-like object (StringIO, file, socket, etc.)
:param stream: A file-like object (BytesIO, file, socket, etc.)
"""
model = self.model_class()
for field_name, field_obj in self._fields.iteritems():
for field_name, field_obj in self._fields.items():
value = field_obj.deserialize(stream)
setattr(model, field_name, value)
return model
Expand All @@ -83,7 +82,7 @@ def __init__(self, coin="bitcoin"):

def _magic_to_text(self):
"""Converts the magic value to a textual representation."""
for k, v in fields.MAGIC_VALUES.iteritems():
for k, v in fields.MAGIC_VALUES.items():
if v == self.magic:
return k
return "Unknown Magic"
Expand Down Expand Up @@ -127,7 +126,7 @@ def _services_to_text(self):
"""Converts the services field into a textual
representation."""
services = []
for service_name, flag_mask in fields.SERVICES.iteritems():
for service_name, flag_mask in fields.SERVICES.items():
if self.services & flag_mask:
services.append(service_name)
return services
Expand All @@ -150,7 +149,7 @@ class IPv4AddressTimestamp(IPv4Address):
"""The IPv4 Address with timestamp."""
def __init__(self):
super(IPv4AddressTimestamp, self).__init__()
self.timestamp = time.time()
self.timestamp = int(time.time())

def __repr__(self):
services = self._services_to_text()
Expand All @@ -174,7 +173,7 @@ class Version(object):
def __init__(self):
self.version = fields.PROTOCOL_VERSION
self.services = fields.SERVICES["NODE_NETWORK"]
self.timestamp = time.time()
self.timestamp = int(time.time())
self.addr_recv = IPv4Address()
self.addr_from = IPv4Address()
self.nonce = random.randint(0, 2**32-1)
Expand Down Expand Up @@ -241,7 +240,7 @@ def __init__(self):

def type_to_text(self):
"""Converts the inventory type to text representation."""
for k, v in fields.INVENTORY_TYPE.iteritems():
for k, v in fields.INVENTORY_TYPE.items():
if v == self.inv_type:
return k
return "Unknown Type"
Expand Down Expand Up @@ -534,4 +533,4 @@ class GetAddrSerializer(Serializer):
"headers": HeaderVectorSerializer,
"mempool": MemPoolSerializer,
"getaddr": GetAddrSerializer,
}
}