viceroy

git clone git://git.codymlewis.com/viceroy.git
Log | Files | Refs | README

commit aa99436619df2f1131c53cc45e9de0593e262f11
parent 8e73a3446d2b11b2c82bf0e9a6f8f50e760e0197
Author: Cody Lewis <cody@codymlewis.com>
Date:   Fri, 16 Oct 2020 17:04:52 +1100

Implemented network programming form of the system

Diffstat:
MClient.py | 64+++++++++++++++++++++++++++++++++++++++++++++++++++++++---------
MGlobalModel.py | 2+-
MServer.py | 69+++++++++++++++++++++++++++++++++++++++++++++++++++++----------------
MSoftMaxModel.py | 4++++
4 files changed, 113 insertions(+), 26 deletions(-)

diff --git a/Client.py b/Client.py @@ -1,20 +1,66 @@ +""" +Classes and functions for the client networking aspect of federated learning + +Author: Cody Lewis +""" + import socket +import pickle + +import torch +import torch.nn as nn import SoftMaxModel class Client: - def __init__(self, num_in, num_out): - self.net = SoftMaxModel.SoftMaxModel(num_in, num_out) + def __init__(self, x, y): + self.net = SoftMaxModel.SoftMaxModel(len(x[0]), len(y[0])) + self.x = x + self.y = y + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) def connect(self, host, port): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.connect((host, port)) - s.sendall(b'Hello, there') - print(s.recv(1024)) + """Connect to the host:port federated learning server""" + self.socket.connect((host, port)) + self.net.copy_params(pickle.loads(self.socket.recv(1024))) + + def fit(self): + if self.socket.recv(1024) != b'OK': + pass + criterion = nn.BCELoss() + e = 0 + while True: + e += 1 + history, grads = self.net.fit(self.x, self.y, 1, verbose=False) + self.socket.sendall(pickle.dumps(grads)) + print( + f"Epoch: {e}, Loss: {criterion(self.net(X), Y)}", + end="\r" + ) + if self.socket.recv(1024) != b'OK': + break + # An improvement would be to save grads as a backlog and concurrent + # send them when the server is ready + print() if __name__ == '__main__': - client = Client(2, 2) - client.connect('127.0.0.1', 5000) + X = torch.tensor([ + [0, 0], + [0, 1], + [1, 0], + [1, 1] + ], dtype=torch.float32) + Y = torch.tensor([ + [1, 0], + [1, 0], + [1, 0], + [0, 1] + ], dtype=torch.float32) + client = Client(X, Y) + HOST, PORT = '127.0.0.1', 5000 + print(f"Connecting to {HOST}:{PORT}") + client.connect(HOST, PORT) + client.fit() diff --git a/GlobalModel.py b/GlobalModel.py @@ -33,7 +33,7 @@ class GlobalModel: def get_params(self): """Get the tensor form parameters of this model""" - return [p.data for p in self.net.parameters()] + return self.net.get_params() def fed_avg(net, num_clients, grads, lr): diff --git a/Server.py b/Server.py @@ -5,16 +5,23 @@ Author: Cody Lewis """ import socket +import pickle +import torch +import torch.nn as nn import GlobalModel class Server: """Federated learning server class""" - def __init__(self, num_in, num_out): + def __init__(self, num_in, num_out, port=5000): self.net = GlobalModel.GlobalModel(num_in, num_out) self.num_clients = 0 - self.address = ('', 5000) + self.address = ('', port) + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind(self.address) + self.clients = [] def accept_client(self, s): """Accept a client and update the model accordingly""" @@ -25,22 +32,52 @@ class Server: def accept_clients(self, num_clients): """Accept some clients to the system""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(self.address) - s.listen(num_clients) - clients = [ - (c, addr) for c, addr in [ - self.accept_client(s) for _ in range(num_clients) - ] + self.socket.listen(num_clients) + self.clients.extend([ + (c, addr) for c, addr in [ + self.accept_client(self.socket) for _ in range(num_clients) ] - for c, addr in clients: - msg = c.recv(1024) - print(f"{addr} >> {msg}") - c.send(msg) - c.close() + ]) + for c, _ in self.clients: + c.send(pickle.dumps(self.net.get_params())) + + def fit(self, X, Y, epochs): + criterion = nn.BCELoss() + for e in range(epochs): + grads = dict() + for i, (c, _) in enumerate(self.clients): + c.send(b'OK') + grads[i] = pickle.loads(c.recv(4096)) + self.net.fit(1, grads) + print( + f"Epoch: {e + 1}/{epochs}, Loss: {criterion(server.net.predict(X), Y)}", + end="\r" + ) + print() + + def close(self): + for c, _ in self.clients: + c.close() + self.clients = [] + if __name__ == '__main__': - server = Server(2, 2) + PORT = 5000 + server = Server(2, 2, PORT) + print(f"Starting server on port {PORT}") + X = torch.tensor([ + [0, 0], + [0, 1], + [1, 0], + [1, 1] + ], dtype=torch.float32) + Y = torch.tensor([ + [1, 0], + [1, 0], + [1, 0], + [0, 1] + ], dtype=torch.float32) server.accept_clients(2) + server.fit(X, Y, 5000) + server.close() diff --git a/SoftMaxModel.py b/SoftMaxModel.py @@ -57,6 +57,10 @@ class SoftMaxModel(nn.Module): "data_count": len(x) } + def get_params(self): + """Get the tensor form parameters of this model""" + return [p.data for p in self.parameters()] + def copy_params(self, params): """Copy input parameters into self""" for p, t in zip(params, self.parameters()):