You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
278 lines
7.6 KiB
Python
278 lines
7.6 KiB
Python
import socket, ssl, msgpack, tempfile, os
|
|
from collections import deque
|
|
from filelike import FileLike
|
|
|
|
class BaseClient:
|
|
# Overridable settings
|
|
max_mem = 32 * 1024 * 1024 # Maximum amount of memory per RAM-based temp file
|
|
chunk_size = 1024 # Size per chunk of raw datastream
|
|
recv_size = 1024 # Amount of data to receive at once
|
|
|
|
def __init__(self, host=None, port=None, use_ssl=False, allowed_certs=None, conn=None, source=None, **kwargs):
|
|
# Internal variables
|
|
self._tempbuff = b""
|
|
self._read_left = 0
|
|
self._read_buff = b""
|
|
self._tempfiles = {}
|
|
self._last_datastream_id = 10
|
|
self._active_datastreams = {}
|
|
self._filelike_counter = 0
|
|
self._datastream_started = []
|
|
|
|
# Parent reactor
|
|
self.reactor = None
|
|
|
|
self.objtype = "client"
|
|
self.sendq = deque([])
|
|
|
|
if (host is None or port is None) and (conn is None or source is None):
|
|
raise Exception("You must specify either a connection and source address, or a hostname and port.")
|
|
|
|
if host is not None:
|
|
# Treat this as a new client
|
|
self.host = host
|
|
self.port = port
|
|
self.ssl = use_ssl
|
|
self.spawned = False
|
|
|
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
|
|
if self.ssl == True:
|
|
self.stream = ssl.wrap_socket(sock, cert_reqs=ssl.CERT_REQUIRED, ca_certs=allowed_certs)
|
|
else:
|
|
self.stream = sock
|
|
|
|
self.stream.connect((self.host, self.port))
|
|
|
|
self.event_connected()
|
|
|
|
elif conn is not None:
|
|
# Treat this as a client spawned by a server
|
|
self.host = source[0]
|
|
self.port = source[1]
|
|
self.stream = conn
|
|
self.spawned = True
|
|
self.event_connected()
|
|
|
|
def _send_chunk(self, chunk):
|
|
self.stream.send(chunk)
|
|
|
|
def _encode_header(self, chunktype, size, channel):
|
|
bits = ""
|
|
bits += bin(chunktype)[2:].zfill(7)
|
|
bits += bin(size)[2:].zfill(25)
|
|
bits += bin(channel)[2:].zfill(24)
|
|
|
|
header = b""
|
|
|
|
for i in xrange(0, 7):
|
|
header += chr(int(bits[i*8:(i+1)*8], 2))
|
|
|
|
return header
|
|
|
|
def _decode_header(self, header):
|
|
bits = ""
|
|
|
|
for i in xrange(0, len(header)):
|
|
bits += bin(ord(header[i]))[2:].zfill(8)
|
|
|
|
chunktype = int(bits[:7], 2) # Bits 0-6 inc
|
|
chunksize = int(bits[7:32], 2) # Bits 7-31 inc
|
|
channel = int(bits[32:56], 2) # Bits 32-55 inc
|
|
|
|
return (chunktype, chunksize, channel)
|
|
|
|
def _pack(self, data):
|
|
return msgpack.packb(data, default=self._encode_pack)
|
|
|
|
def _encode_pack(self, obj):
|
|
if hasattr(obj, "read"):
|
|
datastream_id = self._create_datastream(obj)
|
|
|
|
# Determine the total size of the file
|
|
current_pos = obj.tell()
|
|
obj.seek(0, os.SEEK_END)
|
|
total_size = obj.tell()
|
|
obj.seek(current_pos)
|
|
|
|
obj = {"__type__": "file", "__id__": datastream_id, "__size__": total_size}
|
|
|
|
return obj
|
|
|
|
def _unpack(self, data):
|
|
return msgpack.unpackb(data, object_hook=self._decode_unpack)
|
|
|
|
def _decode_unpack(self, obj):
|
|
if "__type__" in obj:
|
|
if obj['__type__'] == "file":
|
|
# TODO: Validate item
|
|
datastream_id = obj['__id__']
|
|
size = obj['__size__']
|
|
self._create_tempfile(datastream_id)
|
|
obj = self._tempfiles[datastream_id]
|
|
obj._total_size = size
|
|
|
|
self._datastream_started.append(datastream_id)
|
|
self.event_datastream_start(obj, size)
|
|
|
|
return obj
|
|
|
|
def _create_datastream(self, obj):
|
|
datastream_id = self._get_datastream_id()
|
|
self._active_datastreams[datastream_id] = obj
|
|
return datastream_id
|
|
|
|
def _get_datastream_id(self):
|
|
self._last_datastream_id += 1
|
|
|
|
if self._last_datastream_id > 10000:
|
|
self._last_datastream_id = 10
|
|
|
|
if self._last_datastream_id in self._active_datastreams:
|
|
return self._get_datastream_id()
|
|
|
|
return self._last_datastream_id
|
|
|
|
def _create_tempfile(self, datastream_id):
|
|
# This creates a temporary file for the specified datastream if it does not already exist.
|
|
if datastream_id not in self._tempfiles:
|
|
self._filelike_counter += 1
|
|
self._tempfiles[datastream_id] = FileLike(tempfile.SpooledTemporaryFile(max_size=self.max_mem), self._filelike_counter)
|
|
|
|
def _receive_datastream(self, datastream_id, data):
|
|
self._create_tempfile(datastream_id)
|
|
obj = self._tempfiles[datastream_id]
|
|
obj.write(data)
|
|
obj._bytes_finished += len(data)
|
|
|
|
if datastream_id in self._datastream_started:
|
|
self.event_datastream_progress(obj, obj._bytes_finished, obj._total_size)
|
|
|
|
def _send_system_message(self, data):
|
|
encoded = self._pack(data)
|
|
header = self._encode_header(3, len(encoded), 1)
|
|
self.sendq.append(header + encoded)
|
|
|
|
def _read_cycle(self):
|
|
fileno = self.stream.fileno()
|
|
|
|
while True:
|
|
try:
|
|
buff = self.stream.recv(self.recv_size)
|
|
break
|
|
except ssl.SSLError, err:
|
|
if err.args[0] == ssl.SSL_ERROR_WANT_READ:
|
|
select.select([self.stream], [], [])
|
|
elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
|
|
select.select([], [self.stream], [])
|
|
else:
|
|
raise
|
|
|
|
if not buff:
|
|
# The client has ceased to exist - most likely it has closed the connection.
|
|
del self.reactor.objmap[fileno]
|
|
self.reactor.queue.remove(self.stream)
|
|
self.event_disconnected()
|
|
|
|
buff = self._tempbuff + buff
|
|
self._tempbuff = b""
|
|
|
|
while buff != b"":
|
|
if self._read_left > 0:
|
|
# Continue reading a chunk in progress
|
|
if self._read_left <= len(buff):
|
|
# All the data we need is in the buffer.
|
|
self._read_buff += buff[:self._read_left]
|
|
buff = buff[self._read_left:]
|
|
|
|
self._read_left = 0
|
|
|
|
self._process_chunk(self._chunktype, self._channel, self._read_buff)
|
|
|
|
self._read_buff = b""
|
|
else:
|
|
# We need to read more data than is in the buffer.
|
|
self._read_buff += buff
|
|
self._read_left -= len(buff)
|
|
buff = b""
|
|
elif len(buff) >= 7:
|
|
# Start reading a new chunk
|
|
header = buff[:7]
|
|
|
|
chunktype, chunksize, channel = self._decode_header(header)
|
|
|
|
buff = buff[7:]
|
|
self._read_left = chunksize
|
|
self._chunksize = chunksize
|
|
self._chunktype = chunktype
|
|
self._channel = channel
|
|
|
|
else:
|
|
# We need more data to do anything meaningful
|
|
self._tempbuff = buff
|
|
buff = b""
|
|
|
|
def _process_chunk(self, chunktype, channel, data):
|
|
if chunktype == 1:
|
|
# Client message
|
|
self.event_receive(self._unpack(data))
|
|
elif chunktype == 2:
|
|
# Datastream chunk
|
|
self._receive_datastream(channel, data)
|
|
elif chunktype == 3:
|
|
# System message
|
|
self._process_system_message(msgpack.unpackb(data))
|
|
|
|
def _process_system_message(self, data):
|
|
if data['type'] == "datastream_finished":
|
|
datastream_id = data['id']
|
|
self.event_datastream_finished(self._tempfiles[datastream_id])
|
|
self._datastream_started.remove(datastream_id)
|
|
del self._tempfiles[datastream_id]
|
|
|
|
def _write_cycle(self):
|
|
if len(self.sendq) > 0:
|
|
item = self.sendq.popleft()
|
|
self._send_chunk(item)
|
|
|
|
if len(self._active_datastreams) > 0:
|
|
to_delete = []
|
|
|
|
for datastream_id, datastream in self._active_datastreams.iteritems():
|
|
data = datastream.read(self.chunk_size)
|
|
|
|
if data:
|
|
header = self._encode_header(2, len(data), datastream_id)
|
|
|
|
self._send_chunk(header + data)
|
|
else:
|
|
# Done with this datastream.
|
|
self._send_system_message({"type": "datastream_finished", "id": datastream_id})
|
|
to_delete.append(datastream_id)
|
|
|
|
for datastream_id in to_delete:
|
|
del self._active_datastreams[datastream_id]
|
|
|
|
def send(self, data):
|
|
encoded = self._pack(data)
|
|
header = self._encode_header(1, len(encoded), 0)
|
|
self.sendq.append(header + encoded)
|
|
|
|
def event_connected(self):
|
|
pass
|
|
|
|
def event_disconnected(self):
|
|
pass
|
|
|
|
def event_receive(self, data):
|
|
pass
|
|
|
|
def event_datastream_start(self, stream, total_bytes):
|
|
pass
|
|
|
|
def event_datastream_progress(self, stream, finished_bytes, total_bytes):
|
|
pass
|
|
|
|
def event_datastream_finished(self, stream):
|
|
pass
|