use crate::command::Command;
use crate::game::{GameBoard, SimulationOutcome};
use crate::geometry::*;

use std::cmp;
use std::collections::HashMap;
use std::ops::*;

use time::{Duration, PreciseTime};

use rand;
use rand::prelude::*;

use arrayvec::ArrayVec;

pub fn choose_move(
    state: &GameBoard,
    previous_root: Option<Node>,
    start_time: &PreciseTime,
    max_time: Duration,
) -> (Command, Node) {
    let mut root_node = match previous_root {
        None => Node {
            state: state.clone(),
            score_sum: ScoreSum::new(),
            player_score_sums: [HashMap::new(), HashMap::new()],
            unexplored: mcts_move_combo(state),
            children: HashMap::new(),
        },
        // Reuse the subtree from last round if the observed state matches one
        // of the cached children; otherwise start a fresh tree.
        Some(mut node) => {
            node.children
                .drain()
                .map(|(_k, n)| n)
                .find(|n| n.state == *state)
                .unwrap_or_else(|| {
                    eprintln!("Previous round did not appear in the cache");
                    Node {
                        state: state.clone(),
                        score_sum: ScoreSum::new(),
                        player_score_sums: [HashMap::new(), HashMap::new()],
                        unexplored: mcts_move_combo(state),
                        children: HashMap::new(),
                    }
                })
        }
    };

    while start_time.to(PreciseTime::now()) < max_time {
        let _ = mcts(&mut root_node);
    }

    eprintln!("Number of simulations: {}", root_node.score_sum.visit_count);
    for (command, score_sum) in &root_node.player_score_sums[0] {
        eprintln!(
            "{} = {} ({} visits)",
            command,
            score_sum.avg().val,
            score_sum.visit_count
        );
    }

    let chosen_command = best_player_move(&root_node);
    // Keep only the children consistent with the chosen command, so the
    // pruned tree can be handed back in as `previous_root` next round.
    root_node
        .children
        .retain(|[c1, _], _| *c1 == chosen_command);
    (chosen_command, root_node)
}
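// Illustrative round loop (a sketch, not part of this module): the caller
// threads the returned `Node` back in as `previous_root` on the next round so
// simulations from earlier rounds are reused. `read_round_state` and `submit`
// are hypothetical stand-ins for the real match I/O, and the 900ms budget is
// an assumed value, not one taken from this codebase.
//
//     let mut previous_root = None;
//     loop {
//         let state = read_round_state();
//         let start_time = PreciseTime::now();
//         let (command, root) = choose_move(
//             &state,
//             previous_root.take(),
//             &start_time,
//             Duration::milliseconds(900), // assumed per-round time budget
//         );
//         submit(command);
//         previous_root = Some(root);
//     }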
pub struct Node {
    state: GameBoard,
    score_sum: ScoreSum,
    // Per-player statistics, keyed by that player's command alone, so that a
    // command's average is taken across all opposing replies.
    player_score_sums: [HashMap<Command, ScoreSum>; 2],
    unexplored: Vec<[Command; 2]>,
    children: HashMap<[Command; 2], Node>,
}

#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
struct Score {
    val: f32,
}

impl AddAssign for Score {
    fn add_assign(&mut self, other: Self) {
        self.val = self.val + other.val;
    }
}

impl Div<u32> for Score {
    type Output = Self;
    fn div(self, other: u32) -> Self {
        Score {
            val: self.val / other as f32,
        }
    }
}

impl cmp::Eq for Score {}

impl cmp::Ord for Score {
    // Scores are never NaN in practice, so treat incomparable values as equal.
    fn cmp(&self, other: &Score) -> cmp::Ordering {
        self.val
            .partial_cmp(&other.val)
            .unwrap_or(cmp::Ordering::Equal)
    }
}

struct ScoreSum {
    sum: Score,
    visit_count: u32,
}

impl ScoreSum {
    fn new() -> ScoreSum {
        ScoreSum {
            sum: Score { val: 0. },
            visit_count: 0,
        }
    }
    fn with_initial(score: Score) -> ScoreSum {
        ScoreSum {
            sum: score,
            visit_count: 1,
        }
    }
    fn avg(&self) -> Score {
        self.sum / self.visit_count
    }
}

impl AddAssign<Score> for ScoreSum {
    fn add_assign(&mut self, other: Score) {
        self.sum += other;
        self.visit_count = self.visit_count.saturating_add(1);
    }
}

fn mcts(node: &mut Node) -> Score {
    if node.state.outcome != SimulationOutcome::Continue {
        // Terminal node: score it directly.
        score(&node.state)
    } else if let Some(commands) = node.unexplored.pop() {
        // Expansion: simulate one untried move combination and estimate its
        // value with a random rollout.
        let mut new_state = node.state.clone();
        new_state.simulate(commands);
        let score = rollout(&new_state);
        // TODO: This could overshoot, trying to estimate from concluded game
        let unexplored = mcts_move_combo(&new_state);

        let new_node = Node {
            state: new_state,
            score_sum: ScoreSum::with_initial(score),
            player_score_sums: [HashMap::new(), HashMap::new()],
            unexplored,
            children: HashMap::new(),
        };
        node.children.insert(commands, new_node);
        update(node, commands, score);
        score
    } else {
        // Selection: every move combination has a child, so descend into the
        // most promising one and backpropagate its result.
        let commands = choose_existing(node);
        let score = mcts(
            node.children
                .get_mut(&commands)
                .expect("The existing node hasn't been tried yet"),
        );
        update(node, commands, score);
        score
    }
}

fn mcts_move_combo(state: &GameBoard) -> Vec<[Command; 2]> {
    let player_moves = valid_moves(state, 0);
    let opponent_moves = valid_moves(state, 1);
    debug_assert!(player_moves.len() > 0, "No player moves");
    debug_assert!(opponent_moves.len() > 0, "No opponent moves");

    // Cartesian product of both players' moves, since moves are simultaneous.
    let mut result = Vec::with_capacity(player_moves.len() * opponent_moves.len());
    for p in &player_moves {
        for o in &opponent_moves {
            result.push([p.clone(), o.clone()]);
        }
    }
    result
}

fn best_player_move(node: &Node) -> Command {
    node.player_score_sums[0]
        .iter()
        .max_by_key(|(_command, score_sum)| score_sum.avg())
        .map(|(command, _score_sum)| *command)
        .unwrap_or(Command::DoNothing)
}

fn score(state: &GameBoard) -> Score {
    Score {
        val: match state.outcome {
            SimulationOutcome::PlayerWon(0) => 500.,
            SimulationOutcome::PlayerWon(1) => -500.,
            _ => (state.players[0].score() - state.players[1].score()) as f32,
        },
    }
}

fn rollout(state: &GameBoard) -> Score {
    // Play uniformly random (but heuristically filtered) moves to the end of
    // the game, then score the final state.
    let mut s = state.clone();
    let mut rng = rand::thread_rng();
    while s.outcome == SimulationOutcome::Continue {
        let player_moves = rollout_moves(&s, 0);
        let opponent_moves = rollout_moves(&s, 1);
        s.simulate([
            player_moves
                .choose(&mut rng)
                .cloned()
                .unwrap_or(Command::DoNothing),
            opponent_moves
                .choose(&mut rng)
                .cloned()
                .unwrap_or(Command::DoNothing),
        ]);
    }
    score(&s)
}

fn choose_existing(node: &Node) -> [Command; 2] {
    [choose_one_existing(node, 0), choose_one_existing(node, 1)]
}

fn choose_one_existing(node: &Node, player_index: usize) -> Command {
    let ln_n = (node.score_sum.visit_count as f32).ln();
    let c = 100.;
    let multiplier = if player_index == 0 { 1. } else { -1. };
    node.player_score_sums[player_index]
        .iter()
        .max_by_key(|(_command, score_sum)| {
            // UCB1 from this player's perspective: player 1 minimises, so its
            // average is negated, but the exploration bonus is always added.
            (multiplier * score_sum.avg().val
                + c * (ln_n / score_sum.visit_count as f32).sqrt()) as i32
        })
        .map(|(command, _score_sum)| *command)
        .unwrap_or(Command::DoNothing)
}
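// The selection key above is UCB1, avg + c * sqrt(ln(N) / n), where N is the
// parent's visit count and n the child's. The constant c = 100 appears tuned
// to the raw score scale (wins are worth +/-500). As a worked example: after
// N = 1000 parent visits, a move with only n = 10 visits still carries an
// exploration bonus of 100 * sqrt(ln(1000) / 10) ≈ 83 points. The `as i32`
// cast truncates the key because f32 is not Ord; at this score scale the lost
// sub-point precision rarely changes the argmax.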
fn update(node: &mut Node, commands: [Command; 2], score: Score) {
    // Backpropagation: credit the simulated score to each player's command
    // and to the node as a whole.
    *node.player_score_sums[0]
        .entry(commands[0])
        .or_insert(ScoreSum::new()) += score;
    *node.player_score_sums[1]
        .entry(commands[1])
        .or_insert(ScoreSum::new()) += score;
    node.score_sum += score;
}

// fn heuristic_moves(state: &GameBoard, player_index: usize) -> Vec<Command> {
//     let worm = state.players[player_index].active_worm();
//     let shoots = state
//         .valid_shoot_commands(player_index, worm.position, worm.weapon_range);
//     let closest_powerup = state.powerups
//         .iter()
//         .min_by_key(|p| (p.position - worm.position).walking_distance());
//     let average_player_position = Point2d {
//         x: state.players[player_index].worms
//             .iter()
//             .map(|w| w.position.x)
//             .sum::<i8>() / state.players[player_index].worms.len() as i8,
//         y: state.players[player_index].worms
//             .iter()
//             .map(|w| w.position.y)
//             .sum::<i8>() / state.players[player_index].worms.len() as i8
//     };
//     let closest_opponent = state.players[GameBoard::opponent(player_index)].worms
//         .iter()
//         .min_by_key(|w| (w.position - average_player_position).walking_distance());

//     let mut commands = if !shoots.is_empty() {
//         // we're in combat now. Feel free to move anywhere.
//         let moves = state.valid_move_commands(player_index);
//         moves.iter().chain(shoots.iter()).cloned().collect()
//     } else if let Some(powerup) = closest_powerup {
//         // there are powerups! Let's go grab the closest one.
//         moves_towards(state, player_index, powerup.position)
//     } else if let Some(opponent) = closest_opponent {
//         // we're not currently in combat. Let's go find the closest worm.
//         moves_towards(state, player_index, opponent.position)
//     } else {
//         // this shouldn't happen
//         debug_assert!(false, "No valid heuristic moves");
//         vec!()
//     };
//     commands.push(Command::DoNothing);
//     commands
// }

// fn moves_towards(state: &GameBoard, player_index: usize, to: Point2d) -> Vec<Command> {
//     let distance = (to - state.players[player_index].active_worm().position).walking_distance();
//     state.valid_move_commands(player_index)
//         .iter()
//         .filter(|c| match c {
//             Command::Move(p) | Command::Dig(p) => (to - *p).walking_distance() < distance,
//             _ => false
//         })
//         .cloned()
//         .collect()
// }

fn rollout_moves(state: &GameBoard, player_index: usize) -> ArrayVec<[Command; 8]> {
    if let Some(worm) = state.players[player_index].active_worm() {
        // Prefer shooting when a shot is available; otherwise move randomly.
        let shoots = state.valid_shoot_commands(player_index, worm.position, worm.weapon_range);
        if !shoots.is_empty() {
            return shoots;
        }

        // TODO: More directed destruction movements?
        state.valid_move_commands(player_index)
    } else {
        [Command::DoNothing].iter().cloned().collect()
    }
}

fn valid_moves(state: &GameBoard, player_index: usize) -> ArrayVec<[Command; 17]> {
    if let Some(worm) = state.players[player_index].active_worm() {
        state
            .valid_shoot_commands(player_index, worm.position, worm.weapon_range)
            .iter()
            .chain(state.valid_move_commands(player_index).iter())
            .cloned()
            .collect()
    } else {
        [Command::DoNothing].iter().cloned().collect()
    }
}
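// A minimal sanity-test sketch for the scoring types above. It deliberately
// avoids the MCTS functions, which need a fully constructed `GameBoard`; only
// the self-contained `Score`/`ScoreSum` arithmetic is exercised.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn score_sum_averages_accumulated_scores() {
        let mut sum = ScoreSum::new();
        sum += Score { val: 10. };
        sum += Score { val: 20. };
        assert_eq!(sum.visit_count, 2);
        assert_eq!(sum.avg().val, 15.);
    }

    #[test]
    fn with_initial_counts_as_one_visit() {
        let sum = ScoreSum::with_initial(Score { val: 42. });
        assert_eq!(sum.visit_count, 1);
        assert_eq!(sum.avg().val, 42.);
    }

    #[test]
    fn scores_order_by_value() {
        assert!(Score { val: 1. } > Score { val: -1. });
    }
}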