summaryrefslogtreecommitdiff
path: root/src/strategy.rs
diff options
context:
space:
mode:
authorJustin Worthe <justin@worthe-it.co.za>2019-05-14 00:45:49 +0200
committerJustin Worthe <justin@worthe-it.co.za>2019-05-14 00:45:49 +0200
commitdcbd04dfdc6dd6dac88020d3a51f23fa5905c356 (patch)
treedc02ab4951f01f6c1561928390e848f8f415ecac /src/strategy.rs
parent652242e584ee2b7cfb3021d570a63e57cfa52773 (diff)
Filled in the rest of the MCTS
Problem: The current random things isn't actually finding any victorious end states. This game easily meanders if it's played without purpose.
Diffstat (limited to 'src/strategy.rs')
-rw-r--r--src/strategy.rs101
1 files changed, 73 insertions, 28 deletions
diff --git a/src/strategy.rs b/src/strategy.rs
index db4409e..d6f92a6 100644
--- a/src/strategy.rs
+++ b/src/strategy.rs
@@ -2,10 +2,14 @@ use crate::command::Command;
use crate::game::{GameBoard, SimulationOutcome};
use crate::geometry::*;
+use std::cmp;
use std::ops::*;
-use std::collections::{HashMap, HashSet};
+use std::collections::{HashMap};
use time::{Duration, PreciseTime};
+use rand;
+use rand::prelude::*;
+
pub fn choose_move(state: &GameBoard, start_time: &PreciseTime, max_time: Duration) -> Command {
let mut root_node = Node {
state: state.clone(),
@@ -19,6 +23,11 @@ pub fn choose_move(state: &GameBoard, start_time: &PreciseTime, max_time: Durati
let _ = mcts(&mut root_node);
}
+ eprintln!("Number of simulations: {}", root_node.score_sum.visit_count);
+ for (command, score_sum) in &root_node.player_score_sums[0] {
+ eprintln!("{} = {} ({} visits)", command, score_sum.avg().val, score_sum.visit_count);
+ }
+
best_player_move(&root_node)
}
@@ -30,41 +39,42 @@ struct Node {
children: HashMap<[Command; 2], Node>,
}
-impl Node {
- fn score(&self) -> Score {
- self.score_sum.avg()
- }
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
struct Score {
- val: i32
+ val: f32
}
impl AddAssign for Score {
fn add_assign(&mut self, other: Self) {
- self.val += other.val;
+ self.val = self.val + other.val;
}
}
-impl Div<i32> for Score {
+impl Div<u32> for Score {
type Output = Self;
- fn div(self, other: i32) -> Self {
+ fn div(self, other: u32) -> Self {
Score {
- val: self.val / other
+ val: self.val / other as f32
}
}
}
+impl cmp::Eq for Score {}
+impl cmp::Ord for Score {
+ fn cmp(&self, other: &Score) -> cmp::Ordering {
+ self.val.partial_cmp(&other.val).unwrap_or(cmp::Ordering::Equal)
+ }
+}
+
struct ScoreSum {
sum: Score,
- visit_count: i32
+ visit_count: u32
}
impl ScoreSum {
fn new() -> ScoreSum {
ScoreSum {
- sum: Score { val: 0 },
+ sum: Score { val: 0. },
visit_count: 0
}
}
@@ -82,7 +92,7 @@ impl ScoreSum {
impl AddAssign<Score> for ScoreSum {
fn add_assign(&mut self, other: Score) {
self.sum += other;
- self.visit_count += 1;
+ self.visit_count = self.visit_count.saturating_add(1);
}
}
@@ -108,43 +118,78 @@ fn mcts(node: &mut Node) -> Score {
score
} else {
let commands = choose_existing(node);
- let score = mcts(node.children.get_mut(&commands).unwrap());
+ let score = mcts(node.children.get_mut(&commands).expect("The existing node hasn't been tried yet"));
update(node, commands, score);
score
}
}
fn best_player_move(node: &Node) -> Command {
- // TODO, use player_score_sums?
node
- .children
+ .player_score_sums[0]
.iter()
- .max_by_key(|(_k, v)| v.score())
- .map(|(k, _v)| k[0])
+ .max_by_key(|(_command, score_sum)| {
+ score_sum.avg()
+ })
+ .map(|(command, _score_sum)| *command)
.unwrap_or(Command::DoNothing)
}
fn score(state: &GameBoard) -> Score {
+ let mutiplier = match state.outcome {
+ SimulationOutcome::PlayerWon(_) => 100.,
+ _ => 1.
+ };
Score {
- val: state.players[0].health() - state.players[1].health()
+ val: mutiplier * (state.players[0].health() - state.players[1].health()) as f32
}
}
fn rollout(state: &GameBoard) -> Score {
- // TODO
- Score { val: 0 }
+ let mut s = state.clone();
+ let mut rng = rand::thread_rng();
+ while s.outcome == SimulationOutcome::Continue {
+ let player_moves = valid_moves(&s, 0);
+ let opponent_moves = valid_moves(&s, 1);
+
+ s.simulate([
+ player_moves.choose(&mut rng).cloned().unwrap_or(Command::DoNothing),
+ opponent_moves.choose(&mut rng).cloned().unwrap_or(Command::DoNothing)
+ ]);
+ }
+
+ score(&s)
}
fn choose_existing(node: &Node) -> [Command; 2] {
- // TODO
[
- Command::DoNothing,
- Command::DoNothing
+ choose_one_existing(node, 0),
+ choose_one_existing(node, 1)
]
}
+fn choose_one_existing(node: &Node, player_index: usize) -> Command {
+ let ln_n = (node.score_sum.visit_count as f32).ln();
+ let c = 100.;
+ let multiplier = if player_index == 0 {
+ 1.
+ } else {
+ -1.
+ };
+ node.player_score_sums[player_index]
+ .iter()
+ .max_by_key(|(_command, score_sum)| {
+ (multiplier * (score_sum.avg().val + c * (ln_n / score_sum.visit_count as f32).sqrt())) as i32
+ })
+ .map(|(command, _score_sum)| *command)
+ .unwrap_or(Command::DoNothing)
+}
+
+
fn update(node: &mut Node, commands: [Command; 2], score: Score) {
- // TODO
+ *node.player_score_sums[0].entry(commands[0]).or_insert(ScoreSum::new()) += score;
+ *node.player_score_sums[1].entry(commands[1]).or_insert(ScoreSum::new()) += score;
+ node.score_sum += score;
}
fn valid_move_combo(state: &GameBoard) -> Vec<[Command; 2]> {