//! Concurrency control for atomic swap of ownership.
//!
//! A common pattern for thread pools is that each thread owns a token,
//! and sometimes threads need to exchange tokens. A skeleton example
//! is:
//!
//! ```rust
//! # use std::sync::mpsc::{Sender, Receiver};
//! struct Token;
//! enum Message {
//! // Messages go here
//! };
//! struct Thread {
//! sender_to_other_thread: Sender<Message>,
//! receiver_from_other_thread: Receiver<Message>,
//! token: Token,
//! }
//! impl Thread {
//! fn swap_token(&mut self) {
//! // This function should swap the token with the other thread.
//! }
//! fn handle(&mut self, message: Message) {
//! match message {
//! // Message handlers go here
//! }
//! }
//! fn run(&mut self) {
//! loop {
//! let message = self.receiver_from_other_thread.recv();
//! match message {
//! Ok(message) => self.handle(message),
//! Err(_) => return,
//! }
//! }
//! }
//! }
//! ```
//!
//! To do this with the Rust channels, ownership of the token is first
//! passed from the thread to the channel, then to the other thread,
//! resulting in a transitory state where the thread does not have the
//! token. Typically, to work around this, the thread stores an `Option<Token>`
//! rather than a `Token`:
//!
//! ```rust
//! # use std::sync::mpsc::{self, Sender, Receiver};
//! # use std::mem;
//! # struct Token;
//! enum Message {
//! SwapToken(Token, Sender<Token>),
//! };
//! struct Thread {
//! sender_to_other_thread: Sender<Message>,
//! receiver_from_other_thread: Receiver<Message>,
//! token: Option<Token>, // ANNOYING Option
//! }
//! impl Thread {
//! fn swap_token(&mut self) {
//! let (sender, receiver) = mpsc::channel();
//! let token = self.token.take().unwrap();
//! self.sender_to_other_thread.send(Message::SwapToken(token, sender));
//! let token = receiver.recv().unwrap();
//! self.token = Some(token);
//! }
//! fn handle(&mut self, message: Message) {
//! match message {
//! Message::SwapToken(token, sender) => {
//! let token = mem::replace(&mut self.token, Some(token)).unwrap();
//! sender.send(token).unwrap();
//! }
//! }
//! }
//! }
//! ```
//!
//! This crate provides a synchronization primitive for swapping ownership between threads.
//! The API is similar to channels, except that rather than separate `send(T)` and `recv():T`
//! methods, there is one `swap(T):T`, which swaps a `T` owned by one thread for a `T` owned
//! by the other. For example, it allows an implementation of the thread pool which always
//! owns a token.
//!
//! ```rust
//! # use std::sync::mpsc::{self, Sender, Receiver};
//! # use swapper::{self, Swapper};
//! # struct Token;
//! enum Message {
//! SwapToken(Swapper<Token>),
//! };
//! struct Thread {
//! sender_to_other_thread: Sender<Message>,
//! receiver_from_other_thread: Receiver<Message>,
//! token: Token,
//! }
//! impl Thread {
//! fn swap_token(&mut self) {
//! let (our_swapper, their_swapper) = swapper::swapper();
//! self.sender_to_other_thread.send(Message::SwapToken(their_swapper));
//! our_swapper.swap(&mut self.token).unwrap();
//! }
//! fn handle(&mut self, message: Message) {
//! match message {
//! Message::SwapToken(swapper) => swapper.swap(&mut self.token).unwrap(),
//! }
//! }
//! }
//! ```
use std::mem;
use std::ptr;
use std::sync::Arc;
use std::sync::atomic::AtomicPtr;
use std::sync::atomic::Ordering;
use std::sync::mpsc;
use std::sync::mpsc::Receiver;
use std::sync::mpsc::RecvError;
use std::sync::mpsc::Sender;
use std::sync::mpsc::SendError;
/// A concurrency control for swapping ownership between threads.
pub struct Swapper<T> {
contents: Arc<AtomicPtr<T>>,
wait: Receiver<()>,
notify: Sender<()>,
}
impl<T: Send> Swapper<T> {
/// Swap data.
///
/// If the other half of the swap pair is blocked waiting to swap, then it swaps ownership
/// of the data, then unblocks the other thread. Otherwise it blocks waiting to swap.
pub fn swap(&self, our_ref: &mut T) -> Result<(), SwapError> {
loop {
// Is the other thead blocked waiting to swap? If so, swap and unblock it.
let their_ptr = self.contents.swap(ptr::null_mut(), Ordering::AcqRel);
if let Some(their_ref) = unsafe { their_ptr.as_mut() } {
// The safety of this implementation depends on the other thread being blocked
// while this swap happens.
mem::swap(our_ref, their_ref);
// We have swapped ownership, so its now safe to unblock the other thread.
try!(self.notify.send(()));
return Ok(());
}
// Is the other thead not ready for a swap yet? If so, block waiting to swap.
let their_ptr = self.contents.compare_and_swap(ptr::null_mut(), our_ref, Ordering::AcqRel);
if their_ptr.is_null() {
try!(self.wait.recv());
return Ok(());
}
}
}
}
// Be explicit about implementing Send.
unsafe impl<T: Send> Send for Swapper<T> {}
/// Create a new pair of swappers.
///
/// Both halves share one pointer slot, which starts out null, and are cross-wired through
/// two channels: each half keeps the sending end of the channel whose receiving end is held
/// by the other half.
pub fn swapper<T>() -> (Swapper<T>, Swapper<T>) {
    // One notification channel per direction.
    let (to_first, first_inbox) = mpsc::channel();
    let (to_second, second_inbox) = mpsc::channel();

    // The slot through which a waiting half offers its data; empty (null) to begin with.
    let slot = Arc::new(AtomicPtr::new(ptr::null_mut()));

    let first = Swapper {
        contents: Arc::clone(&slot),
        wait: first_inbox,
        notify: to_second,
    };
    let second = Swapper {
        contents: slot,
        wait: second_inbox,
        notify: to_first,
    };
    (first, second)
}
/// The error returned when a thread attempts to swap with a thread that has dropped its swapper.
///
/// The type carries no data; its private unit field only prevents construction outside
/// this module. It is produced by converting the channel errors (`RecvError`,
/// `SendError<()>`) raised while notifying or waiting for the other half.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct SwapError(());

impl ::std::fmt::Display for SwapError {
    /// Writes a fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        f.write_str("the other half of the swapper pair has been dropped")
    }
}

// Lets `SwapError` be used wherever a standard error is expected (e.g. boxed error types).
impl ::std::error::Error for SwapError {}

impl From<RecvError> for SwapError {
    /// Waiting for the other half failed because its sender is gone.
    fn from(_: RecvError) -> SwapError {
        SwapError(())
    }
}

impl From<SendError<()>> for SwapError {
    /// Notifying the other half failed because its receiver is gone.
    fn from(_: SendError<()>) -> SwapError {
        SwapError(())
    }
}