1 #[cfg(test)] 2 mod tests; 3 4 use crate::std::fmt; 5 use crate::std::sync::{Condvar, Mutex}; 6 7 /// A barrier enables multiple threads to synchronize the beginning 8 /// of some computation. 9 /// 10 /// # Examples 11 /// 12 /// ``` 13 /// use std::sync::{Arc, Barrier}; 14 /// use std::thread; 15 /// 16 /// let n = 10; 17 /// let mut handles = Vec::with_capacity(n); 18 /// let barrier = Arc::new(Barrier::new(n)); 19 /// for _ in 0..n { 20 /// let c = Arc::clone(&barrier); 21 /// // The same messages will be printed together. 22 /// // You will NOT see any interleaving. 23 /// handles.push(thread::spawn(move|| { 24 /// println!("before wait"); 25 /// c.wait(); 26 /// println!("after wait"); 27 /// })); 28 /// } 29 /// // Wait for other threads to finish. 30 /// for handle in handles { 31 /// handle.join().unwrap(); 32 /// } 33 /// ``` 34 pub struct Barrier { 35 lock: Mutex<BarrierState>, 36 cvar: Condvar, 37 num_threads: usize, 38 } 39 40 // The inner state of a double barrier 41 struct BarrierState { 42 count: usize, 43 generation_id: usize, 44 } 45 46 /// A `BarrierWaitResult` is returned by [`Barrier::wait()`] when all threads 47 /// in the [`Barrier`] have rendezvoused. 48 /// 49 /// # Examples 50 /// 51 /// ``` 52 /// use std::sync::Barrier; 53 /// 54 /// let barrier = Barrier::new(1); 55 /// let barrier_wait_result = barrier.wait(); 56 /// ``` 57 pub struct BarrierWaitResult(bool); 58 59 impl fmt::Debug for Barrier { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 61 f.debug_struct("Barrier").finish_non_exhaustive() 62 } 63 } 64 65 impl Barrier { 66 /// Creates a new barrier that can block a given number of threads. 67 /// 68 /// A barrier will block `n`-1 threads which call [`wait()`] and then wake 69 /// up all threads at once when the `n`th thread calls [`wait()`]. 70 /// 71 /// [`wait()`]: Barrier::wait 72 /// 73 /// # Examples 74 /// 75 /// ``` 76 /// use std::sync::Barrier; 77 /// 78 /// let barrier = Barrier::new(10); 79 /// ``` 80 #[must_use] new(n: usize) -> Barrier81 pub fn new(n: usize) -> Barrier { 82 Barrier { 83 lock: Mutex::new(BarrierState { 84 count: 0, 85 generation_id: 0, 86 }), 87 cvar: Condvar::new(), 88 num_threads: n, 89 } 90 } 91 92 /// Blocks the current thread until all threads have rendezvoused here. 93 /// 94 /// Barriers are re-usable after all threads have rendezvoused once, and can 95 /// be used continuously. 96 /// 97 /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that 98 /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning 99 /// from this function, and all other threads will receive a result that 100 /// will return `false` from [`BarrierWaitResult::is_leader()`]. 101 /// 102 /// # Examples 103 /// 104 /// ``` 105 /// use std::sync::{Arc, Barrier}; 106 /// use std::thread; 107 /// 108 /// let n = 10; 109 /// let mut handles = Vec::with_capacity(n); 110 /// let barrier = Arc::new(Barrier::new(n)); 111 /// for _ in 0..n { 112 /// let c = Arc::clone(&barrier); 113 /// // The same messages will be printed together. 114 /// // You will NOT see any interleaving. 115 /// handles.push(thread::spawn(move|| { 116 /// println!("before wait"); 117 /// c.wait(); 118 /// println!("after wait"); 119 /// })); 120 /// } 121 /// // Wait for other threads to finish. 122 /// for handle in handles { 123 /// handle.join().unwrap(); 124 /// } 125 /// ``` wait(&self) -> BarrierWaitResult126 pub fn wait(&self) -> BarrierWaitResult { 127 let mut lock = self.lock.lock().unwrap(); 128 let local_gen = lock.generation_id; 129 lock.count += 1; 130 if lock.count < self.num_threads { 131 let _guard = self 132 .cvar 133 .wait_while(lock, |state| local_gen == state.generation_id) 134 .unwrap(); 135 BarrierWaitResult(false) 136 } else { 137 lock.count = 0; 138 lock.generation_id = lock.generation_id.wrapping_add(1); 139 self.cvar.notify_all(); 140 BarrierWaitResult(true) 141 } 142 } 143 } 144 145 impl fmt::Debug for BarrierWaitResult { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 147 f.debug_struct("BarrierWaitResult") 148 .field("is_leader", &self.is_leader()) 149 .finish() 150 } 151 } 152 153 impl BarrierWaitResult { 154 /// Returns `true` if this thread is the "leader thread" for the call to 155 /// [`Barrier::wait()`]. 156 /// 157 /// Only one thread will have `true` returned from their result, all other 158 /// threads will have `false` returned. 159 /// 160 /// # Examples 161 /// 162 /// ``` 163 /// use std::sync::Barrier; 164 /// 165 /// let barrier = Barrier::new(1); 166 /// let barrier_wait_result = barrier.wait(); 167 /// println!("{:?}", barrier_wait_result.is_leader()); 168 /// ``` 169 #[must_use] is_leader(&self) -> bool170 pub fn is_leader(&self) -> bool { 171 self.0 172 } 173 } 174