summary refs log tree commit diff
diff options
context:
space:
mode:
author Justin Worthe <justin@worthe-it.co.za> 2019-05-12 00:58:14 +0200
committer Justin Worthe <justin@worthe-it.co.za> 2019-05-12 00:58:14 +0200
commit 3d1676842e20c90bb5599daa2caefdea2bbf9fe8 (patch)
tree c5c45ef35064fad01ed690be8555273224765e67
parent ebe7d5cd5cc42d8f3f02ca926f4b920ada03765f (diff)
Outline of MCTS
-rw-r--r-- Cargo.lock 18
-rw-r--r-- Cargo.toml 6
-rw-r--r-- src/command.rs 2
-rw-r--r-- src/game.rs 35
-rw-r--r-- src/geometry/direction.rs 2
-rw-r--r-- src/main.rs 9
-rw-r--r-- src/strategy.rs 144
7 files changed, 178 insertions, 38 deletions
diff --git a/Cargo.lock b/Cargo.lock
index aaf9d87..19a813f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -170,6 +170,11 @@ dependencies = [
]
[[package]]
+name = "redox_syscall"
+version = "0.1.54"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
+[[package]]
name = "ryu"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -211,6 +216,7 @@ dependencies = [
"rand 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.90 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.39 (registry+https://github.com/rust-lang/crates.io-index)",
+ "time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -224,6 +230,16 @@ dependencies = [
]
[[package]]
+name = "time"
+version = "0.1.42"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "libc 0.2.51 (registry+https://github.com/rust-lang/crates.io-index)",
+ "redox_syscall 0.1.54 (registry+https://github.com/rust-lang/crates.io-index)",
+ "winapi 0.3.7 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
name = "unicode-xid"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -270,11 +286,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum rand_pcg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "abf9b09b01790cfe0364f52bf32995ea3c39f4d2dd011eac241d2914146d0b44"
"checksum rand_xorshift 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cbf7e9e623549b0e21f6e97cf8ecf247c1a8fd2e8a992ae265314300b2455d5c"
"checksum rdrand 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2"
+"checksum redox_syscall 0.1.54 (registry+https://github.com/rust-lang/crates.io-index)" = "12229c14a0f65c4f1cb046a3b52047cdd9da1f4b30f8a39c5063c8bae515e252"
"checksum ryu 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)" = "eb9e9b8cde282a9fe6a42dd4681319bfb63f121b8a8ee9439c6f4107e58a46f7"
"checksum serde 1.0.90 (registry+https://github.com/rust-lang/crates.io-index)" = "aa5f7c20820475babd2c077c3ab5f8c77a31c15e16ea38687b4c02d3e48680f4"
"checksum serde_derive 1.0.90 (registry+https://github.com/rust-lang/crates.io-index)" = "58fc82bec244f168b23d1963b45c8bf5726e9a15a9d146a067f9081aeed2de79"
"checksum serde_json 1.0.39 (registry+https://github.com/rust-lang/crates.io-index)" = "5a23aa71d4a4d43fdbfaac00eff68ba8a06a51759a89ac3304323e800c4dd40d"
"checksum syn 0.15.31 (registry+https://github.com/rust-lang/crates.io-index)" = "d2b4cfac95805274c6afdb12d8f770fa2d27c045953e7b630a81801953699a9a"
+"checksum time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "db8dcfca086c1143c9270ac42a2bbd8a7ee477b78ac8e45b19abfb0cbede4b6f"
"checksum unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc"
"checksum winapi 0.3.7 (registry+https://github.com/rust-lang/crates.io-index)" = "f10e386af2b13e47c89e7236a7a14a086791a2b88ebad6df9bf42040195cf770"
"checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
diff --git a/Cargo.toml b/Cargo.toml
index 0848b66..af0623b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,5 +7,9 @@ edition = "2018"
serde = { version = "1.0.90", features = ["derive"] }
serde_json = "1.0.39"
rand = "0.6.5"
+time = "0.1.42"
num-traits = "0.2.6"
-arrayvec = "0.4.10"
\ No newline at end of file
+arrayvec = "0.4.10"
+
+[profile.release]
+debug = true
\ No newline at end of file
diff --git a/src/command.rs b/src/command.rs
index a510120..bca0f38 100644
--- a/src/command.rs
+++ b/src/command.rs
@@ -1,7 +1,7 @@
use std::fmt;
use crate::geometry::Direction;
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Command {
Move(i8, i8),
Dig(i8, i8),
diff --git a/src/game.rs b/src/game.rs
index be2dcce..b6a051f 100644
--- a/src/game.rs
+++ b/src/game.rs
@@ -10,6 +10,7 @@ pub struct GameBoard {
pub players: [Player; 2],
pub powerups: ArrayVec<[Powerup; 2]>,
pub map: Map,
+ pub outcome: SimulationOutcome
}
#[derive(Debug, PartialEq, Eq, Clone)]
@@ -85,7 +86,8 @@ impl GameBoard {
powerups: json.map.iter().flatten().filter_map(|c| {
c.powerup.clone().map(|p| Powerup::Health(Point2d::new(c.x, c.y), p.value))
}).collect(),
- map
+ map,
+ outcome: SimulationOutcome::Continue
}
}
@@ -201,12 +203,41 @@ impl GameBoard {
player.active_worm = (player.active_worm + 1) % player.worms.len();
}
- match (self.players[0].worms.len(), self.players[1].worms.len()) {
+ self.outcome = match (self.players[0].worms.len(), self.players[1].worms.len()) {
(0, 0) => SimulationOutcome::Draw,
(_, 0) => SimulationOutcome::PlayerWon(0),
(0, _) => SimulationOutcome::PlayerWon(1),
_ => SimulationOutcome::Continue
+ };
+
+ self.outcome
+ }
+
+ /// Looks for the first worm reached when stepping from `center` in direction `dir`.
+ ///
+ /// The walk starts one step away from `center` and advances one cell at a
+ /// time. It stops at the first cell for which `self.map.at` returns anything
+ /// other than `Some(false)`, or at the first cell occupied by a worm of
+ /// either player, in which case that worm is returned. `None` is returned
+ /// when no worm is reached.
+ ///
+ /// For non-diagonal directions `weapon_range` is the number of steps taken;
+ /// for diagonal directions it is `floor((weapon_range + 1) / sqrt(2))`.
+ pub fn find_target(&self, center: Point2d<i8>, dir: Direction, weapon_range: u8) -> Option<&Worm> {
+ let diff = dir.as_vec();
+
+ // Number of cells to step through along `dir`.
+ let range = if dir.is_diagonal() {
+ ((weapon_range as f32 + 1.) / 2f32.sqrt()).floor() as i8
+ } else {
+ weapon_range as i8
+ };
+
+ let mut target_worm: Option<&Worm> = None;
+ for distance in 1..=range {
+ let target = center + diff * distance;
+ match self.map.at(target) {
+ // NOTE(review): presumably `Some(false)` means an open, in-bounds cell - confirm against `Map::at`.
+ Some(false) => {
+ target_worm = self.players.iter()
+ .flat_map(|p| p.worms.iter())
+ .find(|w| w.position == target);
+ if target_worm.is_some() {
+ break;
+ }
+ },
+ // Any other map result ends the walk without a target.
+ _ => break
+ }
}
+ target_worm
}
}
diff --git a/src/geometry/direction.rs b/src/geometry/direction.rs
index 119aeee..a7d9861 100644
--- a/src/geometry/direction.rs
+++ b/src/geometry/direction.rs
@@ -1,7 +1,7 @@
use std::fmt;
use crate::geometry::vec::Vec2d;
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Direction {
North,
NorthEast,
diff --git a/src/main.rs b/src/main.rs
index d6d9a4c..c24565a 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,14 +2,19 @@ use std::io::prelude::*;
use std::io::stdin;
use std::path::Path;
+use time::{Duration, PreciseTime};
+
use steam_powered_wyrm::command::Command;
use steam_powered_wyrm::strategy::choose_move;
use steam_powered_wyrm::json;
use steam_powered_wyrm::game;
fn main() {
+ let max_time = Duration::milliseconds(950);
let mut game_board = None;
for line in stdin().lock().lines() {
+ let start_time = PreciseTime::now();
+
let round_number = line.expect("Failed to read line from stdin: {}");
let command =
@@ -18,13 +23,13 @@ fn main() {
match &mut game_board {
None => {
let new_board = game::GameBoard::new(json_state);
- let command = choose_move(&new_board);
+ let command = choose_move(&new_board, &start_time, max_time);
game_board = Some(new_board);
command
},
Some(game_board) => {
game_board.update(json_state);
- choose_move(&game_board)
+ choose_move(&game_board, &start_time, max_time)
}
}
},
diff --git a/src/strategy.rs b/src/strategy.rs
index dd15854..ce65e54 100644
--- a/src/strategy.rs
+++ b/src/strategy.rs
@@ -1,52 +1,132 @@
use crate::command::Command;
-use crate::game::GameBoard;
+use crate::game::{GameBoard, SimulationOutcome};
use crate::geometry::*;
+use std::ops::*;
+use std::collections::HashMap;
+use time::{Duration, PreciseTime};
+
// A board state together with the subtrees reachable from it, each paired
// with the two commands (one per player) that lead to it.
// NOTE(review): the only visible users of this type are removed lines in this
// diff, so it may be unused after this change - confirm.
struct GameTree {
state: GameBoard,
next_states: Vec<([Command; 2], GameTree)>
}
-pub fn choose_move(state: &GameBoard) -> Command {
- let mut root = GameTree {
+/// Chooses a command by running `mcts` repeatedly within a time budget.
+///
+/// A search tree is rooted at a clone of `state`, and `mcts` is run on it
+/// until the time elapsed since `start_time` reaches `max_time`. The result
+/// is the first command of the key of the root child with the highest
+/// `score()`, or `Command::DoNothing` when the root has no children.
+pub fn choose_move(state: &GameBoard, start_time: &PreciseTime, max_time: Duration) -> Command {
+ let mut root_node = Node {
state: state.clone(),
- next_states: Vec::new()
+ score_sum: Score { val: 0 },
+ visit_count: 0,
+ children: HashMap::new()
};
- let mut last_depth = vec!(&mut root);
-
- for depth in 0.. {
- println!("Trying depth {}", depth);
- println!("{} wide", last_depth.len());
- let mut next_depth = Vec::new();
- for mut tree in last_depth {
- populate_next_states(&mut tree);
- for x in &mut tree.next_states {
- next_depth.push(&mut x.1);
- }
- }
- last_depth = next_depth;
+ // Keep searching until the time budget is used up.
+ while start_time.to(PreciseTime::now()) < max_time {
+ let _ = mcts(&mut root_node);
}
-
- Command::DoNothing
+
+ // Children are keyed by a pair of commands; only the first is returned.
+ root_node
+ .children
+ .iter()
+ .max_by_key(|(_k, v)| v.score())
+ .map(|(k, _v)| k[0])
+ .unwrap_or(Command::DoNothing)
+}
+
+/// A node of the search tree built by `mcts`.
+struct Node {
+ // Game state this node represents.
+ state: GameBoard,
+ // Score total; `Node::score` divides it by `visit_count`.
+ score_sum: Score,
+ // Divisor used by `Node::score`; the root starts at 0, new children at 1.
+ visit_count: u32,
+ // Child nodes, keyed by the pair of commands simulated to reach them.
+ children: HashMap<[Command; 2], Node>
+}
+
+impl Node {
+    /// Mean score of this node: `score_sum` divided by `visit_count`.
+    ///
+    /// A node that has never been visited (`visit_count == 0`, as the root is
+    /// when it is first constructed) has no mean, so a neutral zero score is
+    /// returned instead of panicking on a division by zero.
+    fn score(&self) -> Score {
+        if self.visit_count == 0 {
+            return Score { val: 0 };
+        }
+        self.score_sum / self.visit_count
+    }
+}
+
+/// Integer score used by the search. It derives `Ord`, which lets
+/// `choose_move` rank nodes with `max_by_key`.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
+struct Score {
+ val: i32
}
-fn populate_next_states(tree: &mut GameTree) {
- let valid_player_moves = valid_moves(&tree.state, 0);
- let valid_opponent_moves = valid_moves(&tree.state, 1);
- for player_move in valid_player_moves {
- for opponent_move in &valid_opponent_moves {
- let commands = [player_move, *opponent_move];
- let mut new_state = tree.state.clone();
- let _ = new_state.simulate(commands);
- tree.next_states.push((commands, GameTree {
- state: new_state,
- next_states: Vec::new()
- }));
+impl AddAssign for Score {
+    /// Adds `other` into `self` in place.
+    ///
+    /// The addition saturates at the `i32` bounds. A plain `+=` would panic
+    /// in debug builds and silently wrap in release builds if a long-running
+    /// accumulation ever overflowed, which would flip the sign of the sum.
+    fn add_assign(&mut self, other: Self) {
+        self.val = self.val.saturating_add(other.val);
+    }
+}
+/// Divides a score by an unsigned count, e.g. to turn a sum into a mean.
+impl Div<u32> for Score {
+    type Output = Self;
+
+    /// Truncating integer division of `val` by `other`.
+    ///
+    /// The arithmetic is widened to `i64` because `other as i32` would wrap
+    /// to a negative number for counts above `i32::MAX` and flip the sign of
+    /// the result. The quotient never exceeds `val` in magnitude, so
+    /// narrowing it back to `i32` is lossless.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `other` is zero.
+    fn div(self, other: u32) -> Self {
+        Score {
+            val: (i64::from(self.val) / i64::from(other)) as i32
}
}
}
+/// Runs one search iteration starting at `node` and returns the score that
+/// iteration produced.
+///
+/// * If the game is already decided at `node`, its state is scored directly.
+/// * If `node` still has unsimulated move pairs, one is simulated on a clone
+///   of the state, evaluated with `rollout`, and stored as a new child with a
+///   visit count of 1.
+/// * Otherwise a move pair is picked with `select` and the search recurses
+///   into the matching child.
+///
+/// Whenever a child was created or visited, `update` is called on `node` with
+/// the resulting score.
+fn mcts(node: &mut Node) -> Score {
+    if node.state.outcome != SimulationOutcome::Continue {
+        score(&node.state)
+    } else if has_unsimulated_outcomes(node) {
+        let commands = choose_unsimulated(node);
+
+        let mut new_state = node.state.clone();
+        new_state.simulate(commands);
+        let score = rollout(&new_state);
+
+        let new_node = Node {
+            state: new_state,
+            score_sum: score,
+            visit_count: 1,
+            children: HashMap::new()
+        };
+        node.children.insert(commands, new_node);
+
+        update(node, commands, score);
+        score
+    } else {
+        let commands = select(node);
+        // `select` can name a move pair that was never expanded (for example
+        // while the helper functions are still stubs). Unwrapping the lookup
+        // would panic on the very first iteration, so in that case the
+        // current state is evaluated instead.
+        match node.children.get_mut(&commands).map(mcts) {
+            Some(child_score) => {
+                update(node, commands, child_score);
+                child_score
+            }
+            None => score(&node.state)
+        }
+    }
+}
+
+/// Scores a game state. Placeholder: `state` is currently ignored and a zero
+/// score is always returned.
+fn score(state: &GameBoard) -> Score {
+ // TODO
+ Score { val: 0 }
+}
+
+/// Reports whether `node` still has move pairs without a child node.
+/// Placeholder: `node` is currently ignored and `false` is always returned,
+/// so `mcts` never takes its expansion branch.
+fn has_unsimulated_outcomes(node: &Node) -> bool {
+ // TODO
+ false
+}
+
+/// Picks the move pair that `mcts` simulates to create a new child of `node`.
+/// Placeholder: `node` is currently ignored and a pair of
+/// `Command::DoNothing` is always returned.
+fn choose_unsimulated(node: &Node) -> [Command; 2] {
+ // TODO
+ [
+ Command::DoNothing,
+ Command::DoNothing
+ ]
+}
+
+/// Produces the score that `mcts` assigns to a newly created child whose
+/// state is `state`. Placeholder: `state` is currently ignored and a zero
+/// score is always returned.
+fn rollout(state: &GameBoard) -> Score {
+ // TODO
+ Score { val: 0 }
+}
+
+/// Picks the move pair whose child `mcts` descends into.
+/// Placeholder: `node` is currently ignored and a pair of
+/// `Command::DoNothing` is always returned, whether or not such a child exists.
+fn select(node: &Node) -> [Command; 2] {
+ // TODO
+ [
+ Command::DoNothing,
+ Command::DoNothing
+ ]
+}
+
+/// Called by `mcts` after it has created or visited the child of `node`
+/// reached by `commands`, with the `score` that step produced.
+/// Placeholder: currently does nothing, so no statistics on `node` change.
+fn update(node: &mut Node, commands: [Command; 2], score: Score) {
+ // TODO
+}
+
fn valid_moves(state: &GameBoard, player_index: usize) -> Vec<Command> {
let worm = state.players[player_index].active_worm();
@@ -60,9 +140,11 @@ fn valid_moves(state: &GameBoard, player_index: usize) -> Vec<Command> {
})
.collect::<Vec<_>>();
let mut shoots = Direction::all().iter()
+ .filter(|dir| state.find_target(worm.position, **dir, worm.weapon_range).is_some())
.map(|d| Command::Shoot(*d))
.collect::<Vec<_>>();
moves.append(&mut shoots);
+ moves.retain(|m| *m != Command::DoNothing);
moves
}