commit 056c00dcb314dbba926b5af50b23f146451188ad
parent 8a6e5aac037717e6fd0e00c2016f42d64d583240
Author: Cody Lewis <cody@codymlewis.com>
Date: Wed, 8 Apr 2020 20:31:00 +1000
Improved training code
Diffstat:
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/network_controller.py b/network_controller.py
@@ -256,8 +256,8 @@ def launch():
if __name__ == '__main__':
if "--train" in sys.argv:
- data, labels_bin = (lambda x: (x[:, :6], x[:, 6]))(np.loadtxt("training_data.txt"))
- labels = np.array([[1, 0] if l == 0 else [0, 1] for l in labels_bin])
+ data, lbls = (lambda x: (x[:, :6], x[:, 6]))(np.loadtxt("training_data.txt"))
+ labels = np.array([[1, 0] if l == 0 else [0, 1] for l in lbls])
inputs = keras.Input(shape=(6,))
x = keras.layers.Dense(100, activation=tf.nn.relu)(inputs)
x = keras.layers.Dense(100, activation=tf.nn.relu)(x)
@@ -277,8 +277,9 @@ if __name__ == '__main__':
validation_split=0.2,
callbacks=[keras.callbacks.EarlyStopping(patience=5)]
)
- print("Reached loss: {}".format(history.history['loss'][-1]))
- model.save("model.h5")
- print("Saved model as model.h5")
+ print(f"Reached loss: {history.history['loss'][-1]}")
+ fn = "model.h5"
+ model.save(fn)
+ print("Saved model as {fn}")
else:
boot(["network_controller"])