commit aa99436619df2f1131c53cc45e9de0593e262f11
parent 8e73a3446d2b11b2c82bf0e9a6f8f50e760e0197
Author: Cody Lewis <cody@codymlewis.com>
Date: Fri, 16 Oct 2020 17:04:52 +1100
Implemented network programming form of the system
Diffstat:
4 files changed, 113 insertions(+), 26 deletions(-)
diff --git a/Client.py b/Client.py
@@ -1,20 +1,66 @@
+"""
+Classes and functions for the client networking aspect of federated learning
+
+Author: Cody Lewis
+"""
+
import socket
+import pickle
+
+import torch
+import torch.nn as nn
import SoftMaxModel
class Client:
- def __init__(self, num_in, num_out):
- self.net = SoftMaxModel.SoftMaxModel(num_in, num_out)
+ def __init__(self, x, y):
+ self.net = SoftMaxModel.SoftMaxModel(len(x[0]), len(y[0]))
+ self.x = x
+ self.y = y
+ self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
def connect(self, host, port):
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- s.connect((host, port))
- s.sendall(b'Hello, there')
- print(s.recv(1024))
+ """Connect to the host:port federated learning server"""
+ self.socket.connect((host, port))
+ self.net.copy_params(pickle.loads(self.socket.recv(1024)))
+
+ def fit(self):
+ if self.socket.recv(1024) != b'OK':
+ pass
+ criterion = nn.BCELoss()
+ e = 0
+ while True:
+ e += 1
+ history, grads = self.net.fit(self.x, self.y, 1, verbose=False)
+ self.socket.sendall(pickle.dumps(grads))
+ print(
+ f"Epoch: {e}, Loss: {criterion(self.net(X), Y)}",
+ end="\r"
+ )
+ if self.socket.recv(1024) != b'OK':
+ break
+ # An improvement would be to save grads as a backlog and concurrently
+ # send them when the server is ready
+ print()
if __name__ == '__main__':
- client = Client(2, 2)
- client.connect('127.0.0.1', 5000)
+ X = torch.tensor([
+ [0, 0],
+ [0, 1],
+ [1, 0],
+ [1, 1]
+ ], dtype=torch.float32)
+ Y = torch.tensor([
+ [1, 0],
+ [1, 0],
+ [1, 0],
+ [0, 1]
+ ], dtype=torch.float32)
+ client = Client(X, Y)
+ HOST, PORT = '127.0.0.1', 5000
+ print(f"Connecting to {HOST}:{PORT}")
+ client.connect(HOST, PORT)
+ client.fit()
diff --git a/GlobalModel.py b/GlobalModel.py
@@ -33,7 +33,7 @@ class GlobalModel:
def get_params(self):
"""Get the tensor form parameters of this model"""
- return [p.data for p in self.net.parameters()]
+ return self.net.get_params()
def fed_avg(net, num_clients, grads, lr):
diff --git a/Server.py b/Server.py
@@ -5,16 +5,23 @@ Author: Cody Lewis
"""
import socket
+import pickle
+import torch
+import torch.nn as nn
import GlobalModel
class Server:
"""Federated learning server class"""
- def __init__(self, num_in, num_out):
+ def __init__(self, num_in, num_out, port=5000):
self.net = GlobalModel.GlobalModel(num_in, num_out)
self.num_clients = 0
- self.address = ('', 5000)
+ self.address = ('', port)
+ self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ self.socket.bind(self.address)
+ self.clients = []
def accept_client(self, s):
"""Accept a client and update the model accordingly"""
@@ -25,22 +32,52 @@ class Server:
def accept_clients(self, num_clients):
"""Accept some clients to the system"""
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- s.bind(self.address)
- s.listen(num_clients)
- clients = [
- (c, addr) for c, addr in [
- self.accept_client(s) for _ in range(num_clients)
- ]
+ self.socket.listen(num_clients)
+ self.clients.extend([
+ (c, addr) for c, addr in [
+ self.accept_client(self.socket) for _ in range(num_clients)
]
- for c, addr in clients:
- msg = c.recv(1024)
- print(f"{addr} >> {msg}")
- c.send(msg)
- c.close()
+ ])
+ for c, _ in self.clients:
+ c.send(pickle.dumps(self.net.get_params()))
+
+ def fit(self, X, Y, epochs):
+ criterion = nn.BCELoss()
+ for e in range(epochs):
+ grads = dict()
+ for i, (c, _) in enumerate(self.clients):
+ c.send(b'OK')
+ grads[i] = pickle.loads(c.recv(4096))
+ self.net.fit(1, grads)
+ print(
+ f"Epoch: {e + 1}/{epochs}, Loss: {criterion(server.net.predict(X), Y)}",
+ end="\r"
+ )
+ print()
+
+ def close(self):
+ for c, _ in self.clients:
+ c.close()
+ self.clients = []
+
if __name__ == '__main__':
- server = Server(2, 2)
+ PORT = 5000
+ server = Server(2, 2, PORT)
+ print(f"Starting server on port {PORT}")
+ X = torch.tensor([
+ [0, 0],
+ [0, 1],
+ [1, 0],
+ [1, 1]
+ ], dtype=torch.float32)
+ Y = torch.tensor([
+ [1, 0],
+ [1, 0],
+ [1, 0],
+ [0, 1]
+ ], dtype=torch.float32)
server.accept_clients(2)
+ server.fit(X, Y, 5000)
+ server.close()
diff --git a/SoftMaxModel.py b/SoftMaxModel.py
@@ -57,6 +57,10 @@ class SoftMaxModel(nn.Module):
"data_count": len(x)
}
+ def get_params(self):
+ """Get the tensor form parameters of this model"""
+ return [p.data for p in self.parameters()]
+
def copy_params(self, params):
"""Copy input parameters into self"""
for p, t in zip(params, self.parameters()):