diff options
author | Justin Worthe <justin@worthe-it.co.za> | 2019-05-14 00:45:49 +0200 |
---|---|---|
committer | Justin Worthe <justin@worthe-it.co.za> | 2019-05-14 00:45:49 +0200 |
commit | dcbd04dfdc6dd6dac88020d3a51f23fa5905c356 (patch) | |
tree | dc02ab4951f01f6c1561928390e848f8f415ecac /src/strategy.rs | |
parent | 652242e584ee2b7cfb3021d570a63e57cfa52773 (diff) |
Filled in the rest of the MCTS
Problem: The current random rollouts aren't actually finding any
victorious end states. This game easily meanders if it's played
without purpose.
Diffstat (limited to 'src/strategy.rs')
-rw-r--r-- | src/strategy.rs | 101 |
1 file changed, 73 insertions, 28 deletions
diff --git a/src/strategy.rs b/src/strategy.rs index db4409e..d6f92a6 100644 --- a/src/strategy.rs +++ b/src/strategy.rs @@ -2,10 +2,14 @@ use crate::command::Command; use crate::game::{GameBoard, SimulationOutcome}; use crate::geometry::*; +use std::cmp; use std::ops::*; -use std::collections::{HashMap, HashSet}; +use std::collections::{HashMap}; use time::{Duration, PreciseTime}; +use rand; +use rand::prelude::*; + pub fn choose_move(state: &GameBoard, start_time: &PreciseTime, max_time: Duration) -> Command { let mut root_node = Node { state: state.clone(), @@ -19,6 +23,11 @@ pub fn choose_move(state: &GameBoard, start_time: &PreciseTime, max_time: Durati let _ = mcts(&mut root_node); } + eprintln!("Number of simulations: {}", root_node.score_sum.visit_count); + for (command, score_sum) in &root_node.player_score_sums[0] { + eprintln!("{} = {} ({} visits)", command, score_sum.avg().val, score_sum.visit_count); + } + best_player_move(&root_node) } @@ -30,41 +39,42 @@ struct Node { children: HashMap<[Command; 2], Node>, } -impl Node { - fn score(&self) -> Score { - self.score_sum.avg() - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] struct Score { - val: i32 + val: f32 } impl AddAssign for Score { fn add_assign(&mut self, other: Self) { - self.val += other.val; + self.val = self.val + other.val; } } -impl Div<i32> for Score { +impl Div<u32> for Score { type Output = Self; - fn div(self, other: i32) -> Self { + fn div(self, other: u32) -> Self { Score { - val: self.val / other + val: self.val / other as f32 } } } +impl cmp::Eq for Score {} +impl cmp::Ord for Score { + fn cmp(&self, other: &Score) -> cmp::Ordering { + self.val.partial_cmp(&other.val).unwrap_or(cmp::Ordering::Equal) + } +} + struct ScoreSum { sum: Score, - visit_count: i32 + visit_count: u32 } impl ScoreSum { fn new() -> ScoreSum { ScoreSum { - sum: Score { val: 0 }, + sum: Score { val: 0. 
}, visit_count: 0 } } @@ -82,7 +92,7 @@ impl ScoreSum { impl AddAssign<Score> for ScoreSum { fn add_assign(&mut self, other: Score) { self.sum += other; - self.visit_count += 1; + self.visit_count = self.visit_count.saturating_add(1); } } @@ -108,43 +118,78 @@ fn mcts(node: &mut Node) -> Score { score } else { let commands = choose_existing(node); - let score = mcts(node.children.get_mut(&commands).unwrap()); + let score = mcts(node.children.get_mut(&commands).expect("The existing node hasn't been tried yet")); update(node, commands, score); score } } fn best_player_move(node: &Node) -> Command { - // TODO, use player_score_sums? node - .children + .player_score_sums[0] .iter() - .max_by_key(|(_k, v)| v.score()) - .map(|(k, _v)| k[0]) + .max_by_key(|(_command, score_sum)| { + score_sum.avg() + }) + .map(|(command, _score_sum)| *command) .unwrap_or(Command::DoNothing) } fn score(state: &GameBoard) -> Score { + let mutiplier = match state.outcome { + SimulationOutcome::PlayerWon(_) => 100., + _ => 1. + }; Score { - val: state.players[0].health() - state.players[1].health() + val: mutiplier * (state.players[0].health() - state.players[1].health()) as f32 } } fn rollout(state: &GameBoard) -> Score { - // TODO - Score { val: 0 } + let mut s = state.clone(); + let mut rng = rand::thread_rng(); + while s.outcome == SimulationOutcome::Continue { + let player_moves = valid_moves(&s, 0); + let opponent_moves = valid_moves(&s, 1); + + s.simulate([ + player_moves.choose(&mut rng).cloned().unwrap_or(Command::DoNothing), + opponent_moves.choose(&mut rng).cloned().unwrap_or(Command::DoNothing) + ]); + } + + score(&s) } fn choose_existing(node: &Node) -> [Command; 2] { - // TODO [ - Command::DoNothing, - Command::DoNothing + choose_one_existing(node, 0), + choose_one_existing(node, 1) ] } +fn choose_one_existing(node: &Node, player_index: usize) -> Command { + let ln_n = (node.score_sum.visit_count as f32).ln(); + let c = 100.; + let multiplier = if player_index == 0 { + 1. 
+ } else { + -1. + }; + node.player_score_sums[player_index] + .iter() + .max_by_key(|(_command, score_sum)| { + (multiplier * (score_sum.avg().val + c * (ln_n / score_sum.visit_count as f32).sqrt())) as i32 + }) + .map(|(command, _score_sum)| *command) + .unwrap_or(Command::DoNothing) +} + + fn update(node: &mut Node, commands: [Command; 2], score: Score) { - // TODO + *node.player_score_sums[0].entry(commands[0]).or_insert(ScoreSum::new()) += score; + *node.player_score_sums[1].entry(commands[1]).or_insert(ScoreSum::new()) += score; + node.score_sum += score; } fn valid_move_combo(state: &GameBoard) -> Vec<[Command; 2]> { |