From 70ec00285128f8f9f5fa0e848950212e6a235d43 Mon Sep 17 00:00:00 2001 From: Justin Worthe Date: Fri, 14 Aug 2015 22:12:56 +0200 Subject: Added missing sensors --- src/brain/neural_network.cpp | 81 +++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 43 deletions(-) (limited to 'src/brain/neural_network.cpp') diff --git a/src/brain/neural_network.cpp b/src/brain/neural_network.cpp index d6a5e15..9061194 100644 --- a/src/brain/neural_network.cpp +++ b/src/brain/neural_network.cpp @@ -1,5 +1,6 @@ #include "brain/neural_network.h" #include "brain/neuron.h" +#include NeuralNetwork::NeuralNetwork(std::istream &&networkConfigFile, int numberOfSensors, int numberOfOutputs) { @@ -11,7 +12,9 @@ NeuralNetwork::NeuralNetwork(std::istream &&networkConfigFile, int numberOfSenso } for (int i=0; i(i); + auto output = std::make_shared(i); + _outputs.push_back(output); + _neurons[i] = output; } parseFile(std::move(networkConfigFile)); @@ -21,15 +24,17 @@ NeuralNetwork::NeuralNetwork(std::istream &&networkConfigFile, std::vector { _biasNode = std::make_shared(); - for (int i=0; i(i); - sensor->setActivation(sensorInitialValues.at(i) ? 1 : 0); + sensor->setActivation(sensorInitialValues[i] ? 1 : 0); _sensors[i] = sensor; } for (int i=0; i(i); + auto output = std::make_shared(i); + _outputs.push_back(output); + _neurons[i] = output; } parseFile(std::move(networkConfigFile)); @@ -38,14 +43,18 @@ NeuralNetwork::NeuralNetwork(std::istream &&networkConfigFile, std::vector void NeuralNetwork::parseFile(std::istream &&file) { double weight; - for (std::string src, dest; - file >> src && file >> dest && file >> weight; ) - { - char srcType = src.at(0); - int srcId = std::stoi(src.substr(1)); - char destType = dest.at(0); - int destId = std::stoi(dest.substr(1)); + char srcType; + int srcId; + int destId; + while (file.get(srcType) && + file >> srcId && + file.ignore(std::numeric_limits::max(), 'n') && + file >> destId && + file >> weight && + file.ignore(std::numeric_limits::max(), '\n')) + { + std::shared_ptr source; std::shared_ptr destination; switch (srcType) @@ -56,21 +65,11 @@ void NeuralNetwork::parseFile(std::istream &&file) case 'b': source = _biasNode; break; - case 'n': - source = findOrAddNeuron(srcId); - break; default: - throw 1; - } - switch (destType) - { - case 'n': - destination = findOrAddNeuron(destId); - break; - default: - throw 1; + source = findOrAddNeuron(srcId); } - + destination = findOrAddNeuron(destId); + addLink(source, destination, weight); } @@ -78,34 +77,30 @@ void NeuralNetwork::parseFile(std::istream &&file) void NeuralNetwork::addLink(std::shared_ptr source, std::shared_ptr destination, double weight) { - auto link = std::make_shared(source, weight); - destination->addInput(link); + std::unique_ptr link(new NeuralLink(source, weight)); + destination->addInput(std::move(link)); } std::shared_ptr NeuralNetwork::findOrAddSensor(int id) { - bool sensorExists = _sensors.count(id) > 0; - if (!sensorExists) + auto sensor = _sensors[id]; + if (!sensor) { - _sensors[id] = std::make_shared(id); + sensor = std::make_shared(id); + _sensors[id] = sensor; } - return _sensors.at(id); + return sensor; } std::shared_ptr NeuralNetwork::findOrAddNeuron(int id) { - bool isOutput = _outputs.count(id) > 0; - if (isOutput) - { - return _outputs.at(id); - } - - bool hiddenNeuronExists = _hiddenNodes.count(id) > 0; - if (!hiddenNeuronExists) + auto neuron = _neurons[id]; + if (!neuron) { - _hiddenNodes[id] = std::make_shared(id); + neuron = std::make_shared(id); + _neurons[id] = neuron; } - return _hiddenNodes.at(id); + return neuron; } void NeuralNetwork::setInput(int inputIndex, double activation) @@ -117,13 +112,13 @@ int NeuralNetwork::findMaxOutputIndex() const { double currentMaxActivation = 0; int currentMaxIndex = 0; - for (std::pair> outputPair : _outputs) + for (unsigned int i=0; i<_outputs.size(); ++i) { - double activation = outputPair.second->activation(); + double activation = _outputs[i]->activation(); if (activation >= currentMaxActivation) { currentMaxActivation = activation; - currentMaxIndex = outputPair.first; + currentMaxIndex = i; } } return currentMaxIndex; -- cgit v1.2.3