zed/crates/gpui/src/executor.rs
Yara 🏳️‍⚧️ 1ac2b977fa
Make tasks inherit their callers priority (#46179)
Additionally this extracts more functionality into the RunnableVariant which is renamed to GpuiRunnable. 

Release Notes:

- N/A

---------

Co-authored-by: Lukas Wirth <lukas@zed.dev>
Co-authored-by: Cole Miller <cole@zed.dev>
2026-01-07 15:52:10 +00:00

1187 lines
39 KiB
Rust

use crate::{App, GpuiRunnable, PlatformDispatcher, RunnableMeta, TaskTiming, profiler};
use async_task::Runnable;
use futures::channel::mpsc;
use parking_lot::{Condvar, Mutex};
use smol::prelude::*;
use std::{
cell::Cell,
fmt::Debug,
marker::PhantomData,
mem::{self, ManuallyDrop},
num::NonZeroUsize,
panic::Location,
pin::Pin,
rc::Rc,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
task::{Context, Poll},
thread::{self, ThreadId},
time::{Duration, Instant},
};
use util::TryFutureExt as _;
use waker_fn::waker_fn;
#[cfg(any(test, feature = "test-support"))]
use rand::rngs::StdRng;
/// A pointer to the executor that is currently running,
/// for spawning background tasks.
///
/// Cloning is cheap: the only state is a shared `Arc` to the platform dispatcher.
#[derive(Clone)]
pub struct BackgroundExecutor {
    /// The platform dispatcher that background work is handed to.
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
}
/// A pointer to the executor that is currently running,
/// for spawning tasks on the main thread.
///
/// This is intentionally `!Send` via the `not_send` marker field. This is because
/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
/// only polled from the same thread it was spawned from. These checks would fail when spawning
/// foreground tasks from background threads.
#[derive(Clone)]
pub struct ForegroundExecutor {
    /// The platform dispatcher that main-thread work is handed to.
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
    /// Weak handle stored as `RunnableMeta::app` in every task spawned through this executor.
    /// The tests in this file show such tasks are cancelled once the owning strong `Arc`
    /// (held by the app) has been dropped.
    liveness: std::sync::Weak<()>,
    /// `Rc` is `!Send`/`!Sync`, which makes this executor `!Send`/`!Sync` as well.
    not_send: PhantomData<Rc<()>>,
}
/// Flavours of realtime work; see [`Priority::Realtime`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum RealtimePriority {
    /// Audio task
    Audio,
    /// Other realtime task
    #[default]
    Other,
}

/// The priority a task is scheduled with.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum Priority {
    /// Realtime priority.
    ///
    /// Spawning a task with this priority will spin it off on a separate thread dedicated just to that task.
    Realtime(RealtimePriority),
    /// High priority.
    ///
    /// Only use for tasks that are critical to the user experience / responsiveness of the editor.
    High,
    /// Medium priority, probably suits most of your use cases.
    #[default]
    Medium,
    /// Low priority.
    ///
    /// Prioritize this for background work that can come in large quantities
    /// to not starve the executor of resources for high priority tasks.
    Low,
}

thread_local! {
    /// Priority recorded for the task currently running on this thread; read back by
    /// [`Priority::inherit`]. Every thread starts out at [`Priority::Medium`].
    static CURRENT_TASKS_PRIORITY: Cell<Priority> = const { Cell::new(Priority::Medium) };
}

impl Priority {
    /// Records `self` as this thread's current priority, so that spawns made by the
    /// runnable that is about to run can pick it up through [`Priority::inherit`].
    pub(crate) fn set_as_default_for_spawns(&self) {
        let priority = *self;
        CURRENT_TASKS_PRIORITY.with(|current| current.set(priority));
    }

    /// Returns the priority recorded for the task currently running on this thread.
    pub fn inherit() -> Self {
        CURRENT_TASKS_PRIORITY.with(Cell::get)
    }

    /// Relative weight of this priority for probabilistic scheduling.
    ///
    /// Realtime priorities weigh `0` because they are not part of that scheme.
    // NOTE(review): presumably not every dispatcher/target reads this, hence the allowance.
    #[allow(dead_code)]
    pub(crate) const fn probability(&self) -> u32 {
        match *self {
            Self::High => 60,
            Self::Medium => 30,
            Self::Low => 10,
            Self::Realtime(_) => 0,
        }
    }
}
/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running, but with no way to return a value.
#[must_use]
#[derive(Debug)]
pub struct Task<T>(TaskState<T>);

/// Internal representation of a [`Task`].
#[derive(Debug)]
enum TaskState<T> {
    /// A task that is ready to return a value. The `Option` is emptied once `poll`
    /// has handed the value out.
    Ready(Option<T>),
    /// A task that is currently running.
    Spawned(async_task::Task<T, RunnableMeta>),
}
impl<T> Task<T> {
/// Creates a new task that will resolve with the value
pub fn ready(val: T) -> Self {
Task(TaskState::Ready(Some(val)))
}
/// Detaching a task runs it to completion in the background
pub fn detach(self) {
match self {
Task(TaskState::Ready(_)) => {}
Task(TaskState::Spawned(task)) => task.detach(),
}
}
/// Converts this task into a fallible task that returns `Option<T>`.
///
/// Unlike the standard `Task<T>`, a [`FallibleTask`] will return `None`
/// if the app was dropped while the task is executing.
///
/// # Example
///
/// ```ignore
/// // Background task that gracefully handles app shutdown:
/// cx.background_spawn(async move {
/// let result = foreground_task.fallible().await;
/// if let Some(value) = result {
/// // Process the value
/// }
/// // If None, app was shut down - just exit gracefully
/// }).detach();
/// ```
pub fn fallible(self) -> FallibleTask<T> {
FallibleTask(match self.0 {
TaskState::Ready(val) => FallibleTaskState::Ready(val),
TaskState::Spawned(task) => FallibleTaskState::Spawned(task.fallible()),
})
}
}
impl<E, T> Task<Result<T, E>>
where
    T: 'static,
    E: 'static + Debug,
{
    /// Runs the task to completion on the foreground executor, logging any error it
    /// produces together with the source location that called this method.
    #[track_caller]
    pub fn detach_and_log_err(self, cx: &App) {
        let caller = *core::panic::Location::caller();
        let logged = self.log_tracked_err(caller);
        cx.foreground_executor().spawn(logged).detach();
    }
}
impl<T> Future for Task<T> {
    type Output = T;

    /// Yields the ready value or forwards to the spawned task.
    ///
    /// # Panics
    ///
    /// Panics if a ready task is polled again after it already returned its value.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // SAFETY: nothing inside `Task` is structurally pinned. The ready value is moved
        // out with `Option::take`, and the spawned handle is only polled through
        // `Pin::new`, which requires it to be `Unpin`.
        let this = unsafe { self.get_unchecked_mut() };
        match &mut this.0 {
            TaskState::Spawned(task) => Pin::new(task).poll(cx),
            TaskState::Ready(val) => Poll::Ready(val.take().unwrap()),
        }
    }
}
/// A task that returns `Option<T>` instead of panicking when cancelled.
///
/// Created with [`Task::fallible`] or [`FallibleTask::ready`].
#[must_use]
pub struct FallibleTask<T>(FallibleTaskState<T>);

/// Internal representation of a [`FallibleTask`].
enum FallibleTaskState<T> {
    /// A task that is ready to return a value. The `Option` is emptied once `poll`
    /// has handed the value out; later polls yield `None`.
    Ready(Option<T>),
    /// A task that is currently running (wraps async_task::FallibleTask).
    Spawned(async_task::FallibleTask<T, RunnableMeta>),
}
impl<T> FallibleTask<T> {
/// Creates a new fallible task that will resolve with the value.
pub fn ready(val: T) -> Self {
FallibleTask(FallibleTaskState::Ready(Some(val)))
}
/// Detaching a task runs it to completion in the background.
pub fn detach(self) {
match self.0 {
FallibleTaskState::Ready(_) => {}
FallibleTaskState::Spawned(task) => task.detach(),
}
}
}
impl<T> Future for FallibleTask<T> {
    type Output = Option<T>;

    /// Yields the ready value, or forwards to the spawned task, which resolves to
    /// `None` if it was cancelled.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // SAFETY: nothing inside `FallibleTask` is structurally pinned. The ready value is
        // moved out with `Option::take`, and the spawned handle is only polled through
        // `Pin::new`, which requires it to be `Unpin`.
        let state = unsafe { &mut self.get_unchecked_mut().0 };
        match state {
            FallibleTaskState::Spawned(task) => Pin::new(task).poll(cx),
            FallibleTaskState::Ready(val) => Poll::Ready(val.take()),
        }
    }
}
impl<T> std::fmt::Debug for FallibleTask<T> {
    /// Prints `FallibleTask::Ready` or `FallibleTask::Spawned(<inner task>)`; the ready
    /// value itself is never printed, so `T` needs no `Debug` bound.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (name, spawned) = match &self.0 {
            FallibleTaskState::Ready(_) => ("FallibleTask::Ready", None),
            FallibleTaskState::Spawned(task) => ("FallibleTask::Spawned", Some(task)),
        };
        let mut tuple = f.debug_tuple(name);
        if let Some(task) = spawned {
            tuple.field(task);
        }
        tuple.finish()
    }
}
/// An opaque identifier that can be attached to a task (see
/// `BackgroundExecutor::spawn_labeled`) so tests can refer to it.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    /// Same as [`TaskLabel::new`]: every default label is a fresh one.
    fn default() -> Self {
        TaskLabel::new()
    }
}

impl TaskLabel {
    /// Construct a new task label.
    ///
    /// Labels come from a process-wide counter starting at 1, so each call yields a
    /// label that has not been handed out before.
    pub fn new() -> Self {
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        let id = NEXT_TASK_LABEL.fetch_add(1, Ordering::SeqCst);
        Self(NonZeroUsize::try_from(id).unwrap())
    }
}
/// A boxed, pinned, `'static` future that need not be `Send`; used for main-thread
/// tasks spawned by [`ForegroundExecutor`].
type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;
/// A boxed, pinned, `'static` future that is `Send`; used for background tasks
/// spawned by [`BackgroundExecutor`].
type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;
/// BackgroundExecutor lets you run things on background threads.
/// In production this is a thread pool with no ordering guarantees.
/// In tests this is simulated by running tasks one by one in a deterministic
/// (but arbitrary) order controlled by the `SEED` environment variable.
impl BackgroundExecutor {
    /// Creates an executor that hands all of its work to `dispatcher`.
    #[doc(hidden)]
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self { dispatcher }
    }

    /// Enqueues the given future to be run to completion on a background thread.
    ///
    /// The task is spawned with [`Priority::default`]; use
    /// [`spawn_with_priority`](Self::spawn_with_priority) to choose another priority.
    #[track_caller]
    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_with_priority(Priority::default(), future)
    }

    /// Enqueues the given future to be run to completion on a background thread,
    /// using the given `priority`.
    ///
    /// A [`Priority::Realtime`] task gets a dedicated thread instead of going through
    /// the dispatcher's regular queue.
    #[track_caller]
    pub fn spawn_with_priority<R>(
        &self,
        priority: Priority,
        future: impl Future<Output = R> + Send + 'static,
    ) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), None, priority)
    }

    /// Enqueues the given future to be run to completion on a background thread and blocking the current task on it.
    ///
    /// This allows to spawn background work that borrows from its scope. The current task is
    /// not resumed before the supplied future has finished. If the current task is dropped
    /// while waiting, the drop blocks the dropping thread until the supplied future has either
    /// completed or been dropped, so the borrows it holds never dangle.
    ///
    /// The background task inherits the caller's priority (see [`Priority::inherit`]).
    pub async fn await_on_background<R>(&self, future: impl Future<Output = R> + Send) -> R
    where
        R: Send,
    {
        /// Sets the shared flag and wakes any waiter when dropped.
        struct NotifyOnDrop<'a>(&'a (Condvar, Mutex<bool>));

        impl Drop for NotifyOnDrop<'_> {
            fn drop(&mut self) {
                *self.0.1.lock() = true;
                self.0.0.notify_all();
            }
        }

        /// Blocks the dropping thread until the shared flag has been set.
        struct WaitOnDrop<'a>(&'a (Condvar, Mutex<bool>));

        impl Drop for WaitOnDrop<'_> {
            fn drop(&mut self) {
                let mut done = self.0.1.lock();
                // Condition variables may wake up spuriously, so re-check the flag each time.
                while !*done {
                    self.0.0.wait(&mut done);
                }
            }
        }

        /// The future handed to the dispatcher: the caller's future plus its notifier.
        struct NotifyAfterDrop<'a, F> {
            // Field order matters: struct fields are dropped in declaration order, so
            // `future` (which may borrow the caller's environment) is always gone before
            // `_notify` releases the waiter, whether it completed, was cancelled mid-way,
            // or was never polled at all.
            future: F,
            _notify: NotifyOnDrop<'a>,
        }

        impl<F: Future> Future for NotifyAfterDrop<'_, F> {
            type Output = F::Output;

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // SAFETY: `future` is structurally pinned. `NotifyAfterDrop` never moves it
                // out, has no `Drop` impl of its own and does not implement `Unpin` manually.
                unsafe { self.map_unchecked_mut(|this| &mut this.future) }.poll(cx)
            }
        }

        let dispatcher = self.dispatcher.clone();
        // `#[track_caller]` is a no-op on `async fn`, so this is a location inside this file.
        let location = core::panic::Location::caller();
        let pair = &(Condvar::new(), Mutex::new(false));
        // We need to ensure that cancellation of the parent task does not drop the environment
        // before our own task has completed or been dropped.
        let _wait_guard = WaitOnDrop(pair);
        // The notifier is attached here rather than created inside the spawned future's body:
        // a task cancelled before its first poll is dropped without ever running that body,
        // which would leave `_wait_guard` blocked forever.
        let guarded = NotifyAfterDrop {
            future,
            _notify: NotifyOnDrop(pair),
        };
        // SAFETY: `spawn_unchecked` lets the spawned future borrow non-`'static` data. Those
        // borrows stay valid because `_wait_guard` keeps this frame from being torn down until
        // `guarded` has been dropped, and `guarded.future` is `Send` by its bound.
        let (runnable, task) = unsafe {
            async_task::Builder::new()
                .metadata(RunnableMeta {
                    location,
                    app: None,
                    priority: Priority::inherit(),
                })
                .spawn_unchecked(
                    move |_| guarded,
                    move |runnable| dispatcher.dispatch(GpuiRunnable::GpuiSpawned(runnable), None),
                )
        };
        runnable.schedule();
        task.await
    }

    /// Enqueues the given future to be run to completion on a background thread.
    /// The given label can be used to control the priority of the task in tests.
    #[track_caller]
    pub fn spawn_labeled<R>(
        &self,
        label: TaskLabel,
        future: impl Future<Output = R> + Send + 'static,
    ) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), Some(label), Priority::default())
    }

    /// Shared implementation of the `spawn*` methods.
    ///
    /// Realtime tasks are run on a dedicated thread created through
    /// `PlatformDispatcher::spawn_realtime`; everything else goes through
    /// `PlatformDispatcher::dispatch` together with the optional test `label`.
    #[track_caller]
    fn spawn_internal<R: Send + 'static>(
        &self,
        future: AnyFuture<R>,
        label: Option<TaskLabel>,
        priority: Priority,
    ) -> Task<R> {
        let dispatcher = self.dispatcher.clone();
        let location = core::panic::Location::caller();
        let metadata = RunnableMeta {
            location,
            priority,
            app: None,
        };
        let (runnable, task) = if let Priority::Realtime(realtime) = priority {
            // Each time the task is scheduled, its runnable is sent over this single-slot
            // channel to the dedicated thread, which runs it and records profiler timings.
            let (tx, rx) = flume::bounded::<Runnable<RunnableMeta>>(1);
            dispatcher.spawn_realtime(
                realtime,
                Box::new(move || {
                    // `recv` fails once every sender is gone, i.e. once the task's schedule
                    // function (which owns `tx`) has been dropped.
                    while let Ok(runnable) = rx.recv() {
                        let start = Instant::now();
                        let location = runnable.metadata().location;
                        let mut timing = TaskTiming {
                            location,
                            start,
                            end: None,
                        };
                        profiler::add_task_timing(timing);
                        // Anything spawned while this runnable runs inherits its priority.
                        Priority::Realtime(realtime).set_as_default_for_spawns();
                        runnable.run();
                        timing.end = Some(Instant::now());
                        profiler::add_task_timing(timing);
                    }
                }),
            );
            async_task::Builder::new().metadata(metadata).spawn(
                move |_| future,
                move |runnable| {
                    let _ = tx.send(runnable);
                },
            )
        } else {
            async_task::Builder::new().metadata(metadata).spawn(
                move |_| future,
                move |runnable| dispatcher.dispatch(GpuiRunnable::GpuiSpawned(runnable), label),
            )
        };
        runnable.schedule();
        Task(TaskState::Spawned(task))
    }

    /// Used by the test harness to run an async test in a synchronous fashion.
    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(false, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    /// Block the current thread until the given future resolves.
    /// Consider using `block_with_timeout` instead.
    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(true, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    /// Polls `future` on the current thread, parking between polls, until it resolves
    /// or `timeout` elapses. On timeout the still-pending future is returned as `Err`.
    #[cfg(not(any(test, feature = "test-support")))]
    pub(crate) fn block_internal<Fut: Future>(
        &self,
        _background_only: bool,
        future: Fut,
        timeout: Option<Duration>,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }
        let deadline = timeout.map(|timeout| Instant::now() + timeout);

        let parker = parking::Parker::new();
        let unparker = parker.unparker();
        let waker = waker_fn(move || {
            unparker.unpark();
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    let timeout =
                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
                    if let Some(timeout) = timeout {
                        if !parker.park_timeout(timeout)
                            && deadline.is_some_and(|deadline| deadline < Instant::now())
                        {
                            return Err(future);
                        }
                    } else {
                        parker.park();
                    }
                }
            }
        }
    }

    /// Test-support variant: with a test dispatcher, runs other tasks deterministically
    /// while waiting, and treats a timeout as a (randomized) tick budget rather than
    /// wall-clock time.
    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub(crate) fn block_internal<Fut: Future>(
        &self,
        background_only: bool,
        future: Fut,
        timeout: Option<Duration>,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        use parking::Parker;
        use std::sync::atomic::AtomicBool;

        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }

        // When using a real platform (e.g., MacPlatform for visual tests that need actual
        // Metal rendering), there's no test dispatcher. In this case, we block the thread
        // directly by polling the future and parking until woken. This is required for
        // VisualTestAppContext which uses real platform rendering but still needs blocking
        // behavior for code paths like editor initialization that call block_with_timeout.
        let Some(dispatcher) = self.dispatcher.as_test() else {
            let deadline = timeout.map(|timeout| Instant::now() + timeout);

            let parker = Parker::new();
            let unparker = parker.unparker();
            let waker = waker_fn(move || {
                unparker.unpark();
            });
            let mut cx = std::task::Context::from_waker(&waker);

            loop {
                match future.as_mut().poll(&mut cx) {
                    Poll::Ready(result) => return Ok(result),
                    Poll::Pending => {
                        let timeout = deadline
                            .map(|deadline| deadline.saturating_duration_since(Instant::now()));
                        if let Some(timeout) = timeout {
                            if !parker.park_timeout(timeout)
                                && deadline.is_some_and(|deadline| deadline < Instant::now())
                            {
                                return Err(future);
                            }
                        } else {
                            parker.park();
                        }
                    }
                }
            }
        };

        let mut max_ticks = if timeout.is_some() {
            dispatcher.gen_block_on_ticks()
        } else {
            usize::MAX
        };

        let parker = Parker::new();
        let unparker = parker.unparker();
        let awoken = Arc::new(AtomicBool::new(false));
        let waker = waker_fn({
            let awoken = awoken.clone();
            let unparker = unparker.clone();
            move || {
                awoken.store(true, Ordering::SeqCst);
                unparker.unpark();
            }
        });
        let mut cx = std::task::Context::from_waker(&waker);

        let duration = Duration::from_secs(
            option_env!("GPUI_TEST_TIMEOUT")
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(180),
        );
        let test_should_end_by = Instant::now() + duration;

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    if max_ticks == 0 {
                        return Err(future);
                    }
                    max_ticks -= 1;

                    if !dispatcher.tick(background_only) {
                        // Nothing else ran; if our own waker fired meanwhile, poll again.
                        if awoken.swap(false, Ordering::SeqCst) {
                            continue;
                        }

                        if !dispatcher.parking_allowed() {
                            if dispatcher.advance_clock_to_next_delayed() {
                                continue;
                            }

                            let mut backtrace_message = String::new();
                            let mut waiting_message = String::new();
                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
                                backtrace_message =
                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
                            }
                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
                            }
                            panic!(
                                "parked with nothing left to run{waiting_message}{backtrace_message}",
                            )
                        }

                        dispatcher.push_unparker(unparker.clone());
                        parker.park_timeout(Duration::from_millis(1));
                        if Instant::now() > test_should_end_by {
                            panic!("test timed out after {duration:?} with allow_parking")
                        }
                    }
                }
            }
        }
    }

    /// Block the current thread until the given future resolves
    /// or `duration` has elapsed.
    ///
    /// On timeout the still-pending future is returned as `Err` so it can be resumed later.
    pub fn block_with_timeout<Fut: Future>(
        &self,
        duration: Duration,
        future: Fut,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        self.block_internal(true, future, Some(duration))
    }

    /// Scoped lets you start a number of tasks and waits
    /// for all of them to complete before returning.
    ///
    /// The tasks are spawned with [`Priority::default`]; see
    /// [`scoped_priority`](Self::scoped_priority).
    pub async fn scoped<'scope, F>(&self, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        self.scoped_priority(Priority::default(), scheduler).await
    }

    /// Scoped lets you start a number of tasks and waits
    /// for all of them to complete before returning.
    ///
    /// All tasks queued by `scheduler` are spawned with `priority`.
    pub async fn scoped_priority<'scope, F>(&self, priority: Priority, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone(), priority);
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn_with_priority(scope.priority, f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }

    /// Get the current time.
    ///
    /// Calling this instead of `std::time::Instant::now` allows the use
    /// of fake timers in tests.
    pub fn now(&self) -> Instant {
        self.dispatcher.now()
    }

    /// Returns a task that will complete after the given duration.
    /// Depending on other concurrent tasks the elapsed duration may be longer
    /// than requested. A zero duration yields an already-completed task.
    #[track_caller]
    pub fn timer(&self, duration: Duration) -> Task<()> {
        if duration.is_zero() {
            return Task::ready(());
        }
        // With `#[track_caller]` this records the caller's location, not this file's.
        let location = core::panic::Location::caller();
        let dispatcher = self.dispatcher.clone();
        let (runnable, task) = async_task::Builder::new()
            .metadata(RunnableMeta {
                location,
                priority: Priority::inherit(),
                app: None,
            })
            .spawn(move |_| async move {}, move |runnable| {
                dispatcher.dispatch_after(duration, GpuiRunnable::GpuiSpawned(runnable))
            });
        runnable.schedule();
        Task(TaskState::Spawned(task))
    }

    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        self.dispatcher.as_test().unwrap().start_waiting();
    }

    /// in tests, removes the debugging data added by start_waiting
    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        self.dispatcher.as_test().unwrap().finish_waiting();
    }

    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
    #[cfg(any(test, feature = "test-support"))]
    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
        self.dispatcher.as_test().unwrap().simulate_random_delay()
    }

    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
    #[cfg(any(test, feature = "test-support"))]
    pub fn deprioritize(&self, task_label: TaskLabel) {
        self.dispatcher.as_test().unwrap().deprioritize(task_label)
    }

    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
    #[cfg(any(test, feature = "test-support"))]
    pub fn advance_clock(&self, duration: Duration) {
        self.dispatcher.as_test().unwrap().advance_clock(duration)
    }

    /// in tests, run one task.
    #[cfg(any(test, feature = "test-support"))]
    pub fn tick(&self) -> bool {
        self.dispatcher.as_test().unwrap().tick(false)
    }

    /// in tests, run all tasks that are ready to run. If after doing so
    /// the test still has outstanding tasks, this will panic. (See also [`Self::allow_parking`])
    #[cfg(any(test, feature = "test-support"))]
    pub fn run_until_parked(&self) {
        self.dispatcher.as_test().unwrap().run_until_parked()
    }

    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
    /// do take real async time to run.
    #[cfg(any(test, feature = "test-support"))]
    pub fn allow_parking(&self) {
        self.dispatcher.as_test().unwrap().allow_parking();
    }

    /// undoes the effect of [`Self::allow_parking`].
    #[cfg(any(test, feature = "test-support"))]
    pub fn forbid_parking(&self) {
        self.dispatcher.as_test().unwrap().forbid_parking();
    }

    /// adds detail to the "parked with nothing left to run" message.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_waiting_hint(&self, msg: Option<String>) {
        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
    }

    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
    #[cfg(any(test, feature = "test-support"))]
    pub fn rng(&self) -> StdRng {
        self.dispatcher.as_test().unwrap().rng()
    }

    /// How many CPUs are available to the dispatcher.
    ///
    /// Fixed at 4 in test builds so that tests are deterministic.
    pub fn num_cpus(&self) -> usize {
        #[cfg(any(test, feature = "test-support"))]
        return 4;

        #[cfg(not(any(test, feature = "test-support")))]
        return num_cpus::get();
    }

    /// Whether we're on the main thread.
    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }

    #[cfg(any(test, feature = "test-support"))]
    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
    }
}
/// ForegroundExecutor runs things on the main thread.
impl ForegroundExecutor {
    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
    ///
    /// `liveness` is stored (as `RunnableMeta::app`) in the metadata of every task spawned
    /// through [`spawn`](Self::spawn) and [`spawn_with_priority`](Self::spawn_with_priority).
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>, liveness: std::sync::Weak<()>) -> Self {
        Self {
            dispatcher,
            liveness,
            not_send: PhantomData,
        }
    }

    /// Enqueues the given Task to run on the main thread at some point in the
    /// future. This inherits the priority of the caller (see [`Priority::inherit`]).
    /// Use [`spawn_with_priority`](Self::spawn_with_priority) if you want to
    /// overwrite that.
    #[track_caller]
    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
    where
        R: 'static,
    {
        self.inner_spawn(self.liveness.clone(), Priority::inherit(), future)
    }

    /// Enqueues the given Task to run on the main thread at some point in the future,
    /// using the given `priority`.
    #[track_caller]
    pub fn spawn_with_priority<R>(
        &self,
        priority: Priority,
        future: impl Future<Output = R> + 'static,
    ) -> Task<R>
    where
        R: 'static,
    {
        self.inner_spawn(self.liveness.clone(), priority, future)
    }

    /// Spawns `future` on the main thread with `app` as its liveness handle.
    ///
    /// The spawned future panics if it is polled or dropped on a thread other than the
    /// one that spawned it (see `spawn_local_with_source_location`).
    #[track_caller]
    pub(crate) fn inner_spawn<R>(
        &self,
        app: std::sync::Weak<()>,
        priority: Priority,
        future: impl Future<Output = R> + 'static,
    ) -> Task<R>
    where
        R: 'static,
    {
        let dispatcher = self.dispatcher.clone();
        let location = core::panic::Location::caller();

        // The future is boxed before reaching this helper, so its body is generic only over
        // the output type `R`, not over every concrete future type.
        #[track_caller]
        fn inner<R: 'static>(
            dispatcher: Arc<dyn PlatformDispatcher>,
            future: AnyLocalFuture<R>,
            location: &'static core::panic::Location<'static>,
            app: std::sync::Weak<()>,
            priority: Priority,
        ) -> Task<R> {
            let (runnable, task) = spawn_local_with_source_location(
                future,
                move |runnable| {
                    dispatcher.dispatch_on_main_thread(GpuiRunnable::GpuiSpawned(runnable))
                },
                RunnableMeta {
                    location,
                    priority,
                    app: Some(app),
                },
            );
            runnable.schedule();
            Task(TaskState::Spawned(task))
        }

        inner::<R>(dispatcher, Box::pin(future), location, app, priority)
    }
}
/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// The spawned future is wrapped so that polling or dropping it on any thread other than
/// the one that called this function panics, naming the spawn location (the caller's,
/// thanks to `#[track_caller]`).
///
/// Copy-modified from:
/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
#[track_caller]
fn spawn_local_with_source_location<Fut, S, M>(
    future: Fut,
    schedule: S,
    metadata: M,
) -> (Runnable<M>, async_task::Task<Fut::Output, M>)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<M> + Send + Sync + 'static,
    M: 'static,
{
    /// Returns the current thread's id, cached in a thread-local. Falls back to
    /// `thread::current()` when the thread-local is no longer accessible (`try_with`
    /// fails during/after thread-local destruction).
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    /// Wraps a future and asserts that it is only polled or dropped on the thread
    /// identified by `id`.
    struct Checked<F> {
        id: ThreadId,
        // `ManuallyDrop` so `Drop for Checked` can run its thread assertion before the
        // inner future is dropped.
        inner: ManuallyDrop<F>,
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            assert!(
                self.id == thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is dropped exactly once, here, and never touched afterwards.
            unsafe { ManuallyDrop::drop(&mut self.inner) };
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is structurally pinned; `Checked` never moves it out.
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    // Wrap the future into one that checks which thread it's on.
    let future = Checked {
        id: thread_id(),
        inner: ManuallyDrop::new(future),
        location: Location::caller(),
    };

    // SAFETY: `spawn_unchecked` skips the `Send` requirement on the future. This relies on
    // `Checked` panicking whenever the future is polled or dropped off its spawning thread.
    // The future and its output are `'static`, so no borrow can dangle.
    unsafe {
        async_task::Builder::new()
            .metadata(metadata)
            .spawn_unchecked(move |_| future, schedule)
    }
}
/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
    /// Executor used when dropping the scope to block until all queued futures are gone.
    executor: BackgroundExecutor,
    /// Priority the queued futures are spawned with by `scoped`/`scoped_priority`.
    priority: Priority,
    /// Futures queued by [`Scope::spawn`], with their `'a` lifetime erased.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    /// Sender cloned into every queued future. Nothing is ever sent on it: the channel
    /// is only used to detect when every clone has been dropped.
    tx: Option<mpsc::Sender<()>>,
    /// Ends (yields `None`) once `tx` and all of its clones have been dropped.
    rx: mpsc::Receiver<()>,
    /// Ties the scope to the `'a` lifetime of the data its futures may borrow.
    lifetime: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    /// Creates an empty scope whose futures will be spawned on `executor` with `priority`.
    fn new(executor: BackgroundExecutor, priority: Priority) -> Self {
        let (tx, rx) = mpsc::channel(1);
        Self {
            executor,
            priority,
            tx: Some(tx),
            rx,
            futures: Default::default(),
            lifetime: PhantomData,
        }
    }

    /// How many CPUs are available to the dispatcher.
    pub fn num_cpus(&self) -> usize {
        self.executor.num_cpus()
    }

    /// Spawn a future into this scope.
    ///
    /// The future is only queued here; it is spawned by the `scoped`/`scoped_priority`
    /// call that owns this scope.
    #[track_caller]
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        // Each queued future holds a sender clone until it finishes (or is dropped), which
        // is what lets the scope's `Drop` detect completion.
        let tx = self.tx.clone().unwrap();

        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
        // dropping this `Scope` blocks until all of the futures have resolved.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(async move {
                f.await;
                drop(tx);
            }))
        };
        self.futures.push(f);
    }
}
impl Drop for Scope<'_> {
    /// Blocks until every future queued on this scope has either resolved or been dropped,
    /// so none of them can outlive the `'a` data they borrow.
    fn drop(&mut self) {
        self.tx.take().unwrap();
        // Futures that were queued but never spawned (e.g. because the scheduler closure
        // panicked before `scoped` took them) still hold a sender clone. Fields are only
        // dropped after this method returns, so release them now; otherwise the wait below
        // would never finish.
        self.futures.clear();
        // Wait until the channel is closed, which means that all of the spawned
        // futures have resolved.
        self.executor.block(self.rx.next());
    }
}
/// Tests for task execution and for cancellation of foreground tasks once the app is dropped.
#[cfg(test)]
mod test {
    use super::*;
    use crate::{App, TestDispatcher, TestPlatform};
    use rand::SeedableRng;
    use std::cell::RefCell;

    /// Helper to create test infrastructure.
    /// Returns (dispatcher, background_executor, app) where app's foreground_executor has liveness.
    fn create_test_app() -> (TestDispatcher, BackgroundExecutor, Rc<crate::AppCell>) {
        let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0));
        let arc_dispatcher = Arc::new(dispatcher.clone());
        // Create liveness for task cancellation
        let liveness = std::sync::Arc::new(());
        let liveness_weak = std::sync::Arc::downgrade(&liveness);
        let background_executor = BackgroundExecutor::new(arc_dispatcher.clone());
        let foreground_executor = ForegroundExecutor::new(arc_dispatcher, liveness_weak);

        let platform = TestPlatform::new(background_executor.clone(), foreground_executor);
        let asset_source = Arc::new(());
        let http_client = http_client::FakeHttpClient::with_404_response();

        // The strong `liveness` handle is handed to the app; dropping the app drops it.
        let app = App::new_app(platform, liveness, asset_source, http_client);
        (dispatcher, background_executor, app)
    }

    /// A foreground task runs normally while the app is alive.
    #[test]
    fn sanity_test_tasks_run() {
        let (dispatcher, _background_executor, app) = create_test_app();
        let foreground_executor = app.borrow().foreground_executor.clone();

        let task_ran = Rc::new(RefCell::new(false));

        foreground_executor
            .spawn({
                let task_ran = Rc::clone(&task_ran);
                async move {
                    *task_ran.borrow_mut() = true;
                }
            })
            .detach();

        // Run dispatcher while app is still alive
        dispatcher.run_until_parked();

        // Task should have run
        assert!(
            *task_ran.borrow(),
            "Task should run normally when app is alive"
        );
    }

    /// A detached foreground task that has not run yet is cancelled when the app is dropped.
    #[test]
    fn test_task_cancelled_when_app_dropped() {
        let (dispatcher, _background_executor, app) = create_test_app();
        let foreground_executor = app.borrow().foreground_executor.clone();
        let app_weak = Rc::downgrade(&app);

        let task_ran = Rc::new(RefCell::new(false));
        let task_ran_clone = Rc::clone(&task_ran);

        foreground_executor
            .spawn(async move {
                *task_ran_clone.borrow_mut() = true;
            })
            .detach();

        drop(app);

        assert!(app_weak.upgrade().is_none(), "App should have been dropped");

        dispatcher.run_until_parked();

        // The task should have been cancelled, not run
        assert!(
            !*task_ran.borrow(),
            "Task should have been cancelled when app was dropped, but it ran!"
        );
    }

    /// An outer task suspended on an inner task: both are cancelled when the app is dropped.
    #[test]
    fn test_nested_tasks_both_cancel() {
        let (dispatcher, _background_executor, app) = create_test_app();
        let foreground_executor = app.borrow().foreground_executor.clone();
        let app_weak = Rc::downgrade(&app);

        let outer_completed = Rc::new(RefCell::new(false));
        let inner_completed = Rc::new(RefCell::new(false));
        let reached_await = Rc::new(RefCell::new(false));

        let outer_flag = Rc::clone(&outer_completed);
        let inner_flag = Rc::clone(&inner_completed);
        let await_flag = Rc::clone(&reached_await);

        // Channel to block the inner task until we're ready
        let (tx, rx) = futures::channel::oneshot::channel::<()>();

        let inner_executor = foreground_executor.clone();

        foreground_executor
            .spawn(async move {
                let inner_task = inner_executor.spawn({
                    let inner_flag = Rc::clone(&inner_flag);
                    async move {
                        rx.await.ok();
                        *inner_flag.borrow_mut() = true;
                    }
                });

                *await_flag.borrow_mut() = true;

                inner_task.await;

                *outer_flag.borrow_mut() = true;
            })
            .detach();

        // Run dispatcher until outer task reaches the await point
        // The inner task will be blocked on the channel
        dispatcher.run_until_parked();

        // Verify we actually reached the await point before dropping the app
        assert!(
            *reached_await.borrow(),
            "Outer task should have reached the await point"
        );

        // Neither task should have completed yet
        assert!(
            !*outer_completed.borrow(),
            "Outer task should not have completed yet"
        );
        assert!(
            !*inner_completed.borrow(),
            "Inner task should not have completed yet"
        );

        // Drop the channel sender and app while outer is awaiting inner
        drop(tx);
        drop(app);
        assert!(app_weak.upgrade().is_none(), "App should have been dropped");

        // Run dispatcher - both tasks should be cancelled
        dispatcher.run_until_parked();

        // Neither task should have completed (both were cancelled)
        assert!(
            !*outer_completed.borrow(),
            "Outer task should have been cancelled, not completed"
        );
        assert!(
            !*inner_completed.borrow(),
            "Inner task should have been cancelled, not completed"
        );
    }

    /// Awaiting a plain `Task` that was cancelled (app dropped) panics.
    #[test]
    #[should_panic]
    fn test_polling_cancelled_task_panics() {
        let (dispatcher, background_executor, app) = create_test_app();
        let foreground_executor = app.borrow().foreground_executor.clone();
        let app_weak = Rc::downgrade(&app);

        let task = foreground_executor.spawn(async move { 42 });

        drop(app);

        assert!(app_weak.upgrade().is_none(), "App should have been dropped");

        dispatcher.run_until_parked();

        background_executor.block(task);
    }

    /// Awaiting the `FallibleTask` form of a cancelled task yields `None` instead of panicking.
    #[test]
    fn test_polling_cancelled_task_returns_none_with_fallible() {
        let (dispatcher, background_executor, app) = create_test_app();
        let foreground_executor = app.borrow().foreground_executor.clone();
        let app_weak = Rc::downgrade(&app);

        let task = foreground_executor.spawn(async move { 42 }).fallible();

        drop(app);

        assert!(app_weak.upgrade().is_none(), "App should have been dropped");

        dispatcher.run_until_parked();

        let result = background_executor.block(task);
        assert_eq!(result, None, "Cancelled task should return None");
    }
}