From 2e6ecf423c8228ac8de4badf4fc2d037a876b7ff Mon Sep 17 00:00:00 2001 From: Justin Worthe Date: Sat, 1 Aug 2015 22:50:00 +0200 Subject: Reading network from file --- brain.nn | 4 ++ include/brain/neural_link.h | 2 +- include/brain/neural_network.h | 14 +++++- include/brain/neuron.h | 2 + include/game_state.h | 2 + include/spacebot.h | 6 ++- src/brain/neural_link.cpp | 4 +- src/brain/neural_network.cpp | 103 +++++++++++++++++++++++++++++++++++++++++ src/brain/neuron.cpp | 5 ++ src/game_state.cpp | 5 ++ src/spacebot.cpp | 28 +++++++---- test/neural_network.cpp | 8 ++-- 12 files changed, 162 insertions(+), 21 deletions(-) create mode 100644 brain.nn diff --git a/brain.nn b/brain.nn new file mode 100644 index 0000000..e0c854f --- /dev/null +++ b/brain.nn @@ -0,0 +1,4 @@ +s0 n3 0.5 +n3 n1 1 +b0 n0 0.4 +b0 n3 0.5 \ No newline at end of file diff --git a/include/brain/neural_link.h b/include/brain/neural_link.h index 42ee2f9..5a63ba4 100644 --- a/include/brain/neural_link.h +++ b/include/brain/neural_link.h @@ -7,7 +7,7 @@ class NeuralLink { public: - NeuralLink(double weight); + NeuralLink(std::shared_ptr input, double weight); double weightedActivation() const; private: diff --git a/include/brain/neural_network.h b/include/brain/neural_network.h index 308cb56..b2c441f 100644 --- a/include/brain/neural_network.h +++ b/include/brain/neural_network.h @@ -3,9 +3,12 @@ #include #include #include +#include #include "brain/neural_node.h" #include "brain/sensor.h" +#include "brain/bias_node.h" +#include "brain/neuron.h" class NeuralNetwork { @@ -19,8 +22,15 @@ public: int numberOfOutputs() const { return _outputs.size(); } private: - std::vector> _nodes; std::vector> _sensors; - std::vector> _outputs; + std::shared_ptr _biasNode; + std::vector> _hiddenNodes; + std::vector> _outputs; + + void parseFile(std::istream &&file); + + void addLink(std::shared_ptr source, std::shared_ptr destination, double weight); + std::shared_ptr findOrAddSensor(int id); + std::shared_ptr findOrAddNeuron(int id); }; diff --git a/include/brain/neuron.h b/include/brain/neuron.h index 1607cf0..810ce5b 100644 --- a/include/brain/neuron.h +++ b/include/brain/neuron.h @@ -13,6 +13,8 @@ public: virtual ~Neuron() {} virtual double activation() const; + void addInput(std::shared_ptr link); + private: std::vector> _inputLinks; double sigmoid(double input) const; diff --git a/include/game_state.h b/include/game_state.h index 905bbc9..482dd7d 100644 --- a/include/game_state.h +++ b/include/game_state.h @@ -20,6 +20,8 @@ public: const std::vector& missiles() const { return _missiles; } const std::vector& shields() const { return _shields; } const std::vector& spaceships() const { return _spaceships; } + + std::vector toBitArray() const; private: std::vector _aliens; diff --git a/include/spacebot.h b/include/spacebot.h index 9b89383..079b33e 100644 --- a/include/spacebot.h +++ b/include/spacebot.h @@ -9,8 +9,10 @@ public: Spacebot(std::string outputPath); void writeNextMove(); private: - std::string outputFilename; - GameState gameState; + std::string _outputFilename; + std::string _networkConfigFilename; + GameState _gameState; + void writeMove(const Move& move); Move chooseMove(); }; diff --git a/src/brain/neural_link.cpp b/src/brain/neural_link.cpp index d217236..f8d2b29 100644 --- a/src/brain/neural_link.cpp +++ b/src/brain/neural_link.cpp @@ -1,7 +1,7 @@ #include "brain/neural_link.h" -NeuralLink::NeuralLink(double weight) - :_weight(weight) +NeuralLink::NeuralLink(std::shared_ptr input, double weight) + :_input(input), _weight(weight) { } diff --git a/src/brain/neural_network.cpp b/src/brain/neural_network.cpp index 2d3b902..15eedca 100644 --- a/src/brain/neural_network.cpp +++ b/src/brain/neural_network.cpp @@ -3,6 +3,8 @@ NeuralNetwork::NeuralNetwork(std::istream &&networkConfigFile, int numberOfSensors, int numberOfOutputs) { + _biasNode = std::make_shared(); + for (int i=0; i(i)); @@ -11,6 +13,107 @@ NeuralNetwork::NeuralNetwork(std::istream &&networkConfigFile, int numberOfSenso { _outputs.push_back(std::make_shared(i)); } + + parseFile(std::move(networkConfigFile)); +} + +void NeuralNetwork::parseFile(std::istream &&file) +{ + double weight; + for (std::string src, dest; + file >> src && file >> dest && file >> weight; ) + { + char srcType = src.at(0); + int srcId = std::stoi(src.substr(1)); + char destType = dest.at(0); + int destId = std::stoi(dest.substr(1)); + + std::shared_ptr source; + std::shared_ptr destination; + switch (srcType) + { + case 's': + source = findOrAddSensor(srcId); + break; + case 'b': + source = _biasNode; + break; + case 'n': + source = findOrAddNeuron(srcId); + break; + default: + throw 1; + } + switch (destType) + { + case 'n': + destination = findOrAddNeuron(destId); + break; + default: + throw 1; + } + + addLink(source, destination, weight); + } + +} + +void NeuralNetwork::addLink(std::shared_ptr source, std::shared_ptr destination, double weight) +{ + auto link = std::make_shared(source, weight); + destination->addInput(link); +} + +std::shared_ptr NeuralNetwork::findOrAddSensor(int id) +{ + std::shared_ptr result; + for (auto node : _sensors) + { + if (node->id() == id) + { + result = node; + break; + } + } + if (!result) + { + result = std::make_shared(id); + _sensors.push_back(result); + } + return result; +} + +std::shared_ptr NeuralNetwork::findOrAddNeuron(int id) +{ + std::shared_ptr result; + for (auto node : _hiddenNodes) + { + if (node->id() == id) + { + result = node; + break; + } + } + if (result) + { + return result; + } + for (auto node : _outputs) + { + if (node->id() == id) + { + result = node; + break; + } + } + if (result) + { + return result; + } + + result = std::make_shared(id); + _hiddenNodes.push_back(result); + return result; } void NeuralNetwork::setInput(int inputIndex, double activation) diff --git a/src/brain/neuron.cpp b/src/brain/neuron.cpp index 8c2e47c..7ea02c6 100644 --- a/src/brain/neuron.cpp +++ b/src/brain/neuron.cpp @@ -22,3 +22,8 @@ double Neuron::activation() const } return sigmoid(activationSum); } + +void Neuron::addInput(std::shared_ptr link) +{ + _inputLinks.push_back(link); +} diff --git a/src/game_state.cpp b/src/game_state.cpp index d99ca12..5fed683 100644 --- a/src/game_state.cpp +++ b/src/game_state.cpp @@ -88,3 +88,8 @@ void GameState::logState() std::cout << "Spaceship" << spaceship.coords() << std::endl; } } + +std::vector GameState::toBitArray() const +{ + return std::vector(); +} diff --git a/src/spacebot.cpp b/src/spacebot.cpp index 5f87df9..1f8f2b8 100644 --- a/src/spacebot.cpp +++ b/src/spacebot.cpp @@ -1,11 +1,12 @@ #include "spacebot.h" #include "move_string_mapper.h" -#include +#include "brain/neural_network.h" #include Spacebot::Spacebot(std::string outputPath) - : outputFilename(outputPath+"/move.txt"), - gameState(std::ifstream(outputPath+"/map.txt")) + : _outputFilename(outputPath+"/move.txt"), + _networkConfigFilename("brain.nn"), + _gameState(std::ifstream(outputPath+"/map.txt")) { } @@ -17,17 +18,24 @@ void Spacebot::writeNextMove() Move Spacebot::chooseMove() { - int min = static_cast(Move::NOTHING); - int max = static_cast(Move::BUILD_SHIELD); - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(min, max); - return static_cast(dis(gen)); + auto sensorInputs = _gameState.toBitArray(); + + NeuralNetwork network(std::ifstream(_networkConfigFilename), + sensorInputs.size(), + static_cast(Move::BUILD_SHIELD)); + + for (int i=0; i(moveInt); } void Spacebot::writeMove(const Move& move) { - std::ofstream resultStream(outputFilename); + std::ofstream resultStream(_outputFilename); resultStream << MoveStringMapper().toString(move) << std::endl; return; } diff --git a/test/neural_network.cpp b/test/neural_network.cpp index 6040c83..dd029d6 100644 --- a/test/neural_network.cpp +++ b/test/neural_network.cpp @@ -25,14 +25,14 @@ SCENARIO("network is read from istream") GIVEN("a valid config file") { std::stringstream file; - file << "s0 n2 0.5" << std::endl; - file << "n2 n1 1" << std::endl; + file << "s0 n3 0.5" << std::endl; + file << "n3 n1 1" << std::endl; file << "b0 n0 0.4" << std::endl; - file << "b0 n2 0.5" << std:: endl; + file << "b0 n3 0.5" << std:: endl; WHEN("the network is initialized") { - NeuralNetwork network(std::move(file), 1, 2); + NeuralNetwork network(std::move(file), 1, 3); THEN("the network is constructed correctly") { -- cgit v1.2.3