diff --git a/daemon/src/main.rs b/daemon/src/main.rs index 381e846..178b503 100644 --- a/daemon/src/main.rs +++ b/daemon/src/main.rs @@ -448,6 +448,7 @@ fn run_with_evdev( return Ok(()); } + let caps = is_caps_lock_on(&device); let key_state = device.get_key_state().ok(); let events = device.fetch_events()?; last_event_time = std::time::Instant::now(); @@ -530,7 +531,13 @@ fn run_with_evdev( if consumed_keys.contains(&keycode) { consumed_keys.remove(&keycode); } - if let Some(ch) = key_to_char(key) { + if let Some(mut ch) = key_to_char(key) { + let shift = is_modifier_held_shift(&key_state); + if ch.is_ascii_alphabetic() { + if shift ^ caps { + ch = ch.to_ascii_uppercase(); + } + } let commands = daemon.process_key(ch); if !commands.is_empty() { consumed_keys.insert(keycode); @@ -745,6 +752,22 @@ fn is_modifier_pressed(key_state: &Option>) -> b || key_state.contains(evdev::Key::KEY_RIGHTMETA) } +fn is_modifier_held_shift(key_state: &Option>) -> bool { + let ks = match key_state { + Some(ks) => ks, + None => return false, + }; + ks.contains(evdev::Key::KEY_LEFTSHIFT) || ks.contains(evdev::Key::KEY_RIGHTSHIFT) +} + +fn is_caps_lock_on(device: &evdev::Device) -> bool { + if let Ok(leds) = device.get_led_state() { + leds.contains(evdev::LedType::LED_CAPSL) + } else { + false + } +} + fn is_toggle_combination_state(key_state: &Option>, key: &str) -> bool { let key_state = match key_state { Some(ks) => ks, diff --git a/engine/src/engine.rs b/engine/src/engine.rs index 5cdd927..e7a9919 100644 --- a/engine/src/engine.rs +++ b/engine/src/engine.rs @@ -64,9 +64,16 @@ impl Engine { } pub fn flush(&mut self) -> Option { - match self.input_method { + let event = match self.input_method { InputMethod::Telex => self.telex.flush(), InputMethod::Vni => self.vni.flush(), + }; + if let Some(EngineEvent::Flush(word)) = event { + let cased = match_casing(&self.raw_buffer, &word); + self.raw_buffer.clear(); + Some(EngineEvent::Flush(cased)) + } else { + event } } @@ -95,15 +102,16 @@ impl Engine { let stripped = strip_diacritics(buffer); let backspaces = buffer.chars().count(); let had_tones = stripped != buffer; + let cased_stripped = match_casing(&self.raw_buffer, &stripped); self.reset(); if had_tones { Some(EngineEvent::UndoTones { backspaces, - restored: stripped, + restored: cased_stripped, }) } else { - Some(EngineEvent::Flush(stripped)) + Some(EngineEvent::Flush(cased_stripped)) } } @@ -137,15 +145,21 @@ impl Engine { return None; } - if ch == ' ' || ch == '\t' || ch == '.' || ch == ',' || ch == '!' || ch == '?' - || ch == ';' || ch == ':' || ch == '\n' + let lowercase_ch = if ch.is_ascii() { + ch.to_ascii_lowercase() + } else { + ch.to_lowercase().next().unwrap_or(ch) + }; + + if lowercase_ch == ' ' || lowercase_ch == '\t' || lowercase_ch == '.' || lowercase_ch == ',' || lowercase_ch == '!' || lowercase_ch == '?' + || lowercase_ch == ';' || lowercase_ch == ':' || lowercase_ch == '\n' { if self.raw_buffer.is_empty() { return None; } // Check for macro expansion before auto-restore - let macro_expansion = self.macros.get(&self.raw_buffer).cloned(); + let macro_expansion = self.macros.get(&self.raw_buffer.to_lowercase()).cloned(); if let Some(expansion) = macro_expansion { let previous_raw_len = self.raw_buffer.chars().count(); self.reset(); @@ -180,13 +194,14 @@ impl Engine { let previous_inner = self.buffer().to_string(); let previous_inner_len = previous_inner.chars().count(); + let previous_inner_cased = match_casing(&self.raw_buffer, &previous_inner); let flush_event = self.flush(); - let mut final_word = previous_inner.clone(); + let mut final_word = previous_inner_cased.clone(); if let Some(EngineEvent::Flush(word)) = flush_event { final_word = word; } - let result = if final_word != previous_inner { + let result = if final_word != previous_inner_cased { Some(EngineEvent::Replace { backspaces: previous_inner_len + 1, insert: format!("{}{}", final_word, ch), @@ -204,17 +219,18 @@ impl Engine { self.raw_buffer.push(ch); match self.input_method { - InputMethod::Telex => { self.telex.process_key(ch); } - InputMethod::Vni => { self.vni.process_key(ch); } + InputMethod::Telex => { self.telex.process_key(lowercase_ch); } + InputMethod::Vni => { self.vni.process_key(lowercase_ch); } } let new_inner = self.buffer().to_string(); - let expected_screen = format!("{}{}", previous_inner, ch); + let expected_screen = format!("{}{}", previous_inner, lowercase_ch); if new_inner != expected_screen { + let cased_inner = match_casing(&self.raw_buffer, &new_inner); Some(EngineEvent::Replace { backspaces: previous_inner.chars().count() + 1, - insert: new_inner, + insert: cased_inner, }) } else { None @@ -265,6 +281,32 @@ fn strip_diacritics(s: &str) -> String { .collect() } +fn match_casing(raw: &str, processed: &str) -> String { + if raw.is_empty() || processed.is_empty() { + return processed.to_string(); + } + + let alphabetic_chars: Vec = raw.chars().filter(|c| c.is_alphabetic()).collect(); + if alphabetic_chars.is_empty() { + return processed.to_string(); + } + + let all_upper = alphabetic_chars.iter().all(|c| c.is_uppercase()); + let first_upper = alphabetic_chars[0].is_uppercase(); + + if all_upper { + processed.to_uppercase() + } else if first_upper { + let mut chars = processed.chars(); + match chars.next() { + Some(first) => first.to_uppercase().collect::() + chars.as_str(), + None => processed.to_string(), + } + } else { + processed.to_string() + } +} + #[cfg(test)] mod tests { use super::*; @@ -318,4 +360,41 @@ mod tests { assert!(output.contains("không")); } + + #[test] + fn test_casing_preservation() { + let mut engine = Engine::new(InputMethod::Telex); + + // Lowercase: "sats" -> "sát" + engine.reset(); + let _ = engine.process_key('s'); + let _ = engine.process_key('a'); + let _ = engine.process_key('t'); + let _ = engine.process_key('s'); + assert_eq!(engine.buffer(), "sát"); + + // Titlecase: "Sats" -> "Sát" + engine.reset(); + engine.process_key('S'); + engine.process_key('a'); + engine.process_key('t'); + let event = engine.process_key('s'); + if let Some(EngineEvent::Replace { insert, .. }) = event { + assert_eq!(insert, "Sát"); + } else { + panic!("Expected Replace event, got {:?}", event); + } + + // Uppercase: "SATS" -> "SÁT" + engine.reset(); + engine.process_key('S'); + engine.process_key('A'); + engine.process_key('T'); + let event2 = engine.process_key('S'); + if let Some(EngineEvent::Replace { insert, .. }) = event2 { + assert_eq!(insert, "SÁT"); + } else { + panic!("Expected Replace event, got {:?}", event2); + } + } } diff --git a/protocol/src/uinput_monitor.rs b/protocol/src/uinput_monitor.rs index ba87b3f..e02dac2 100644 --- a/protocol/src/uinput_monitor.rs +++ b/protocol/src/uinput_monitor.rs @@ -171,13 +171,15 @@ impl UinputInjector { if let Ok(content) = std::fs::read_to_string("/proc/self/loginuid") { if let Ok(uid) = content.trim().parse::() { - unsafe { - let pw = libc::getpwuid(uid); - if !pw.is_null() { - let name = std::ffi::CStr::from_ptr((*pw).pw_name) - .to_string_lossy().into_owned(); - if !name.is_empty() { - return Some(name); + if uid != 4294967295 { + unsafe { + let pw = libc::getpwuid(uid); + if !pw.is_null() { + let name = std::ffi::CStr::from_ptr((*pw).pw_name) + .to_string_lossy().into_owned(); + if !name.is_empty() { + return Some(name); + } } } } @@ -196,34 +198,84 @@ impl UinputInjector { None } + /// Get original non-root UID and GID when running as root. + fn get_original_uid_gid() -> Option<(u32, u32)> { + let is_root = unsafe { libc::getuid() == 0 }; + if !is_root { + return None; + } + + let mut target_uid = None; + + if let Ok(uid_str) = std::env::var("SUDO_UID") { + if let Ok(uid) = uid_str.parse::() { + target_uid = Some(uid); + } + } + + if target_uid.is_none() { + if let Ok(uid_str) = std::env::var("PKEXEC_UID") { + if let Ok(uid) = uid_str.parse::() { + target_uid = Some(uid); + } + } + } + + if target_uid.is_none() { + if let Ok(content) = std::fs::read_to_string("/proc/self/loginuid") { + if let Ok(uid) = content.trim().parse::() { + if uid != 4294967295 { + target_uid = Some(uid); + } + } + } + } + + if let Some(uid) = target_uid { + unsafe { + let pw = libc::getpwuid(uid); + if !pw.is_null() { + let gid = (*pw).pw_gid; + return Some((uid, gid)); + } + } + } + + None + } + /// Run an external command as the original user if we're root. - /// Wayland tools (wtype, wl-copy) need the user's session, not root. - /// Uses explicit `env VAR=val` instead of `--preserve-env` for - /// compatibility with all sudo versions. + /// Uses native OS setuid/setgid to avoid slow PAM/logging/sudo startup overhead. fn run_as_user(program: &str, args: &[&str]) -> std::process::Output { let is_root = unsafe { libc::getuid() == 0 }; if is_root { - if let Some(original_user) = Self::get_original_username() { + if let Some((uid, gid)) = Self::get_original_uid_gid() { let wayland_display = std::env::var("WAYLAND_DISPLAY").unwrap_or_default(); let xdg_runtime_dir = std::env::var("XDG_RUNTIME_DIR").unwrap_or_default(); let display = std::env::var("DISPLAY").unwrap_or_default(); - let mut cmd = std::process::Command::new("sudo"); - cmd.args(["-u", &original_user, "env"]); + + use std::os::unix::process::CommandExt; + let mut cmd = std::process::Command::new(program); + cmd.uid(uid).gid(gid); + if !wayland_display.is_empty() { - cmd.arg(format!("WAYLAND_DISPLAY={}", wayland_display)); + cmd.env("WAYLAND_DISPLAY", wayland_display); } if !xdg_runtime_dir.is_empty() { - cmd.arg(format!("XDG_RUNTIME_DIR={}", xdg_runtime_dir)); + cmd.env("XDG_RUNTIME_DIR", xdg_runtime_dir); } if !display.is_empty() { - cmd.arg(format!("DISPLAY={}", display)); + cmd.env("DISPLAY", display); } - cmd.arg(program); + if let Some(username) = Self::get_original_username() { + cmd.env("HOME", format!("/home/{}", username)); + } + cmd.args(args); match cmd.output() { Ok(output) => return output, Err(e) => { - eprintln!("[vietc] Failed to run sudo -u {} env ... {} {}: {}", original_user, program, args.join(" "), e); + eprintln!("[vietc] Failed to run {} as uid={}: {}", program, uid, e); return std::process::Output { status: std::process::ExitStatus::default(), stdout: vec![], @@ -231,8 +283,6 @@ impl UinputInjector { }; } } - } else { - eprintln!("[vietc] Running as root but could not determine original user"); } } match std::process::Command::new(program).args(args).output() { @@ -271,13 +321,17 @@ impl UinputInjector { // It is Unicode. We must use a single unified channel. let is_wayland = std::env::var("WAYLAND_DISPLAY").is_ok(); + static HAS_WTYPE: std::sync::OnceLock = std::sync::OnceLock::new(); + static HAS_XDOTOOL: std::sync::OnceLock = std::sync::OnceLock::new(); + if is_wayland { - // Under Wayland, we try to use `wtype` for both backspaces and text. - let has_wtype = std::process::Command::new("which") - .arg("wtype") - .output() - .map(|o| o.status.success()) - .unwrap_or(false); + let has_wtype = *HAS_WTYPE.get_or_init(|| { + std::process::Command::new("which") + .arg("wtype") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + }); if has_wtype { let mut args = Vec::new(); @@ -295,12 +349,13 @@ impl UinputInjector { eprintln!("[vietc] wtype inject failed: {}", String::from_utf8_lossy(&output.stderr).trim()); } } else { - // Under X11, we try to use `xdotool` for both backspaces and text. - let has_xdotool = std::process::Command::new("which") - .arg("xdotool") - .output() - .map(|o| o.status.success()) - .unwrap_or(false); + let has_xdotool = *HAS_XDOTOOL.get_or_init(|| { + std::process::Command::new("which") + .arg("xdotool") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + }); if has_xdotool { let mut args = Vec::new();