mirror of
https://github.com/ZSeven-W/openpencil.git
synced 2026-05-31 19:04:29 +07:00
801 lines
25 KiB
Rust
801 lines
25 KiB
Rust
//! Control protocol implementation for bidirectional communication
|
|
//!
|
|
//! This module provides the protocol handler and message types for the control
|
|
//! protocol used in bidirectional communication with Claude Code CLI.
|
|
//!
|
|
//! # Overview
|
|
//!
|
|
//! The control protocol enables:
|
|
//! - Request/response communication
|
|
//! - Hook invocations from CLI to SDK
|
|
//! - Permission requests from CLI to SDK
|
|
//! - Protocol initialization and capability negotiation
|
|
//!
|
|
//! # Example: Basic Protocol Usage
|
|
//!
|
|
//! ```rust
|
|
//! use anthropic_agent_sdk::control::ProtocolHandler;
|
|
//!
|
|
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
|
|
//! let handler = ProtocolHandler::new();
|
|
//!
|
|
//! // Create an initialization request
|
|
//! let init_req = handler.create_init_request();
|
|
//! assert_eq!(init_req.protocol_version, "1.0");
|
|
//!
|
|
//! // Mark as initialized manually (handle_init_response does this after a successful handshake)
|
|
//! handler.set_initialized(true);
|
|
//!
|
|
//! // Create control requests
|
|
//! let interrupt_req = handler.create_interrupt_request();
|
|
//! let msg_req = handler.create_send_message_request("Hello!".to_string());
|
|
//! # Ok(())
|
|
//! # }
|
|
//! ```
|
|
//!
|
|
//! # Example: Handling Hook Events
|
|
//!
|
|
//! ```rust
|
|
//! use anthropic_agent_sdk::control::ProtocolHandler;
|
|
//! use tokio::sync::mpsc;
|
|
//!
|
|
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
|
|
//! let mut handler = ProtocolHandler::new();
|
|
//!
|
|
//! // Set up hook channel
|
|
//! let (hook_tx, mut hook_rx) = mpsc::unbounded_channel();
|
|
//! handler.set_hook_channel(hook_tx);
|
|
//!
|
|
//! // Hook events passed to handler.handle_response() are forwarded to hook_rx
|
|
//! // You can then process it and send a response
|
|
//! tokio::spawn(async move {
|
|
//! while let Some((hook_id, event)) = hook_rx.recv().await {
|
|
//! println!("Received hook: {} {:?}", hook_id, event);
|
|
//! // Process hook and create response...
|
|
//! }
|
|
//! });
|
|
//! # Ok(())
|
|
//! # }
|
|
//! ```
|
|
//!
|
|
//! # Example: Serialization
|
|
//!
|
|
//! ```rust
|
|
//! use anthropic_agent_sdk::control::{ControlMessage, ProtocolHandler};
|
|
//!
|
|
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
|
|
//! let handler = ProtocolHandler::new();
|
|
//! let request = handler.create_interrupt_request();
|
|
//! let message = ControlMessage::Request(request);
|
|
//!
|
|
//! // Serialize to JSON
|
|
//! let json = handler.serialize_message(&message)?;
|
|
//! assert!(json.ends_with('\n'));
|
|
//!
|
|
//! // Deserialize from JSON
|
|
//! let parsed = handler.deserialize_message(json.trim())?;
|
|
//! # Ok(())
|
|
//! # }
|
|
//! ```
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
|
use tokio::sync::{Mutex, mpsc, oneshot};
|
|
|
|
use crate::error::{ClaudeError, Result};
|
|
use crate::types::{HookEvent, PermissionRequest, PermissionResult, RequestId};
|
|
|
|
/// Control message envelope for all protocol messages
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(tag = "type")]
|
|
pub enum ControlMessage {
|
|
/// Request from SDK to CLI
|
|
#[serde(rename = "request")]
|
|
Request(ControlRequest),
|
|
/// Response from CLI to SDK
|
|
#[serde(rename = "response")]
|
|
Response(ControlResponse),
|
|
/// Initialization request
|
|
#[serde(rename = "init")]
|
|
Init(InitRequest),
|
|
/// Initialization response
|
|
#[serde(rename = "init_response")]
|
|
InitResponse(InitResponse),
|
|
}
|
|
|
|
/// Request from SDK to CLI
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(tag = "method", content = "params")]
|
|
pub enum ControlRequest {
|
|
/// Interrupt the current operation
|
|
#[serde(rename = "interrupt")]
|
|
Interrupt {
|
|
/// Unique request identifier
|
|
id: RequestId,
|
|
},
|
|
/// Send a message to Claude
|
|
#[serde(rename = "send_message")]
|
|
SendMessage {
|
|
/// Unique request identifier
|
|
id: RequestId,
|
|
/// Message content to send
|
|
content: String,
|
|
},
|
|
/// Respond to a hook invocation
|
|
#[serde(rename = "hook_response")]
|
|
HookResponse {
|
|
/// Unique request identifier
|
|
id: RequestId,
|
|
/// Hook event ID being responded to
|
|
hook_id: String,
|
|
/// Hook response data
|
|
response: serde_json::Value,
|
|
},
|
|
/// Respond to a permission request
|
|
#[serde(rename = "permission_response")]
|
|
PermissionResponse {
|
|
/// Unique request identifier
|
|
id: RequestId,
|
|
/// Permission request ID being responded to
|
|
request_id: RequestId,
|
|
/// Permission result (Allow/Deny)
|
|
result: PermissionResult,
|
|
},
|
|
/// Set the model for subsequent messages
|
|
#[serde(rename = "set_model")]
|
|
SetModel {
|
|
/// Unique request identifier
|
|
id: RequestId,
|
|
/// Model name or alias (e.g., "haiku", "sonnet", "opus", or full model ID)
|
|
model: Option<String>,
|
|
},
|
|
/// Set the permission mode
|
|
#[serde(rename = "set_permission_mode")]
|
|
SetPermissionMode {
|
|
/// Unique request identifier
|
|
id: RequestId,
|
|
/// Permission mode to use
|
|
mode: String,
|
|
},
|
|
/// Set the maximum thinking tokens
|
|
#[serde(rename = "set_max_thinking_tokens")]
|
|
SetMaxThinkingTokens {
|
|
/// Unique request identifier
|
|
id: RequestId,
|
|
/// Maximum thinking tokens (null to disable)
|
|
max_thinking_tokens: Option<u32>,
|
|
},
|
|
/// Rewind files to a checkpoint
|
|
///
|
|
/// Restores files to their state at the specified user message.
|
|
/// Requires `enable_file_checkpointing: true` in options.
|
|
#[serde(rename = "rewind_files")]
|
|
RewindFiles {
|
|
/// Unique request identifier
|
|
id: RequestId,
|
|
/// UUID of the user message checkpoint to rewind to
|
|
user_message_uuid: String,
|
|
},
|
|
}
|
|
|
|
/// Response from CLI to SDK
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(tag = "status")]
|
|
pub enum ControlResponse {
|
|
/// Successful response
|
|
#[serde(rename = "success")]
|
|
Success {
|
|
/// Request ID this responds to
|
|
id: RequestId,
|
|
/// Optional response data
|
|
data: Option<serde_json::Value>,
|
|
},
|
|
/// Error response
|
|
#[serde(rename = "error")]
|
|
Error {
|
|
/// Request ID this responds to
|
|
id: RequestId,
|
|
/// Error message
|
|
message: String,
|
|
/// Error code
|
|
code: Option<String>,
|
|
},
|
|
/// Hook invocation from CLI
|
|
#[serde(rename = "hook")]
|
|
Hook {
|
|
/// Hook invocation ID
|
|
id: String,
|
|
/// Hook event details
|
|
event: HookEvent,
|
|
},
|
|
/// Permission request from CLI
|
|
#[serde(rename = "permission")]
|
|
Permission {
|
|
/// Permission request ID
|
|
id: RequestId,
|
|
/// Permission request details
|
|
request: PermissionRequest,
|
|
},
|
|
}
|
|
|
|
/// Initialization request sent from SDK to CLI
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct InitRequest {
|
|
/// Protocol version
|
|
pub protocol_version: String,
|
|
/// SDK version
|
|
pub sdk_version: String,
|
|
/// Client capabilities
|
|
pub capabilities: ClientCapabilities,
|
|
}
|
|
|
|
/// Client capabilities for negotiation
|
|
#[allow(clippy::struct_excessive_bools)]
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ClientCapabilities {
|
|
/// Supports bidirectional communication
|
|
pub bidirectional: bool,
|
|
/// Supports hooks
|
|
pub hooks: bool,
|
|
/// Supports permissions
|
|
pub permissions: bool,
|
|
/// Supports interrupts
|
|
pub interrupts: bool,
|
|
}
|
|
|
|
/// Initialization response from CLI to SDK
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct InitResponse {
|
|
/// Protocol version accepted
|
|
pub protocol_version: String,
|
|
/// CLI version
|
|
pub cli_version: String,
|
|
/// Server capabilities
|
|
pub capabilities: ServerCapabilities,
|
|
/// Session ID for this connection
|
|
pub session_id: String,
|
|
}
|
|
|
|
/// Server capabilities advertised by CLI
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ServerCapabilities {
|
|
/// Supports streaming responses
|
|
pub streaming: bool,
|
|
/// Supports tool use
|
|
pub tools: bool,
|
|
/// Supports MCP servers
|
|
pub mcp: bool,
|
|
}
|
|
|
|
/// Bookkeeping for a request registered by `ProtocolHandler::send_request`
/// that is still waiting for its `Success`/`Error` response.
struct PendingRequest {
    /// Sending half of a oneshot channel; `ProtocolHandler::handle_response`
    /// pushes the matching response through it.
    response_tx: oneshot::Sender<ControlResponse>,
}
|
|
|
|
/// Protocol handler for managing control protocol communication.
///
/// Tracks whether the handshake has completed, allocates request IDs, keeps the
/// table of requests awaiting responses, and forwards CLI-initiated hook and
/// permission messages to the channels registered by the caller.
pub struct ProtocolHandler {
    /// Counter behind generated request IDs (`req-1`, `req-2`, ...).
    next_request_id: Arc<AtomicU64>,
    /// Requests awaiting a `Success`/`Error` response, keyed by request ID.
    pending_requests: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
    /// Whether the handshake completed (or was marked complete manually).
    initialized: Arc<AtomicBool>,
    /// Destination for hook invocations; when `None`, they are ignored.
    hook_tx: Option<mpsc::UnboundedSender<(String, HookEvent)>>,
    /// Destination for permission requests; when `None`, they are ignored.
    permission_tx: Option<mpsc::UnboundedSender<(RequestId, PermissionRequest)>>,
}
|
|
|
|
impl ProtocolHandler {
|
|
/// Create a new protocol handler
|
|
#[must_use]
|
|
pub fn new() -> Self {
|
|
Self {
|
|
next_request_id: Arc::new(AtomicU64::new(1)),
|
|
pending_requests: Arc::new(Mutex::new(HashMap::new())),
|
|
initialized: Arc::new(AtomicBool::new(false)),
|
|
hook_tx: None,
|
|
permission_tx: None,
|
|
}
|
|
}
|
|
|
|
/// Set hook callback channel
|
|
pub fn set_hook_channel(&mut self, tx: mpsc::UnboundedSender<(String, HookEvent)>) {
|
|
self.hook_tx = Some(tx);
|
|
}
|
|
|
|
/// Set permission callback channel
|
|
pub fn set_permission_channel(
|
|
&mut self,
|
|
tx: mpsc::UnboundedSender<(RequestId, PermissionRequest)>,
|
|
) {
|
|
self.permission_tx = Some(tx);
|
|
}
|
|
|
|
/// Check if protocol is initialized
|
|
#[must_use]
|
|
pub fn is_initialized(&self) -> bool {
|
|
self.initialized.load(Ordering::SeqCst)
|
|
}
|
|
|
|
/// Set protocol as initialized (for cases where no handshake is needed)
|
|
pub fn set_initialized(&self, value: bool) {
|
|
self.initialized.store(value, Ordering::SeqCst);
|
|
}
|
|
|
|
/// Generate next request ID
|
|
fn next_id(&self) -> RequestId {
|
|
let id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
|
|
RequestId::new(format!("req-{id}"))
|
|
}
|
|
|
|
/// Create initialization request
|
|
#[must_use]
|
|
pub fn create_init_request(&self) -> InitRequest {
|
|
InitRequest {
|
|
protocol_version: "1.0".to_string(),
|
|
sdk_version: crate::VERSION.to_string(),
|
|
capabilities: ClientCapabilities {
|
|
bidirectional: true,
|
|
hooks: true,
|
|
permissions: true,
|
|
interrupts: true,
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Handle initialization response
|
|
///
|
|
/// # Errors
|
|
///
|
|
/// Returns `ProtocolError` if protocol version is unsupported.
|
|
pub fn handle_init_response(&self, response: &InitResponse) -> Result<()> {
|
|
// Validate protocol version
|
|
if response.protocol_version != "1.0" {
|
|
return Err(ClaudeError::protocol_error(format!(
|
|
"Unsupported protocol version: {}",
|
|
response.protocol_version
|
|
)));
|
|
}
|
|
|
|
self.initialized.store(true, Ordering::SeqCst);
|
|
Ok(())
|
|
}
|
|
|
|
/// Send a request and wait for response
|
|
///
|
|
/// # Errors
|
|
///
|
|
/// Returns `ProtocolError` if protocol is not initialized.
|
|
pub async fn send_request(
|
|
&self,
|
|
request: ControlRequest,
|
|
) -> Result<oneshot::Receiver<ControlResponse>> {
|
|
if !self.is_initialized() {
|
|
return Err(ClaudeError::protocol_error(
|
|
"Protocol not initialized - call init first",
|
|
));
|
|
}
|
|
|
|
let id = Self::get_request_id(&request);
|
|
let (response_tx, response_rx) = oneshot::channel();
|
|
|
|
let pending = PendingRequest { response_tx };
|
|
|
|
{
|
|
let mut pending_requests = self.pending_requests.lock().await;
|
|
pending_requests.insert(id, pending);
|
|
}
|
|
|
|
Ok(response_rx)
|
|
}
|
|
|
|
/// Extract request ID from a control request
|
|
fn get_request_id(request: &ControlRequest) -> RequestId {
|
|
match request {
|
|
ControlRequest::Interrupt { id }
|
|
| ControlRequest::SendMessage { id, .. }
|
|
| ControlRequest::HookResponse { id, .. }
|
|
| ControlRequest::PermissionResponse { id, .. }
|
|
| ControlRequest::SetModel { id, .. }
|
|
| ControlRequest::SetPermissionMode { id, .. }
|
|
| ControlRequest::SetMaxThinkingTokens { id, .. }
|
|
| ControlRequest::RewindFiles { id, .. } => id.clone(),
|
|
}
|
|
}
|
|
|
|
/// Handle incoming control response
|
|
///
|
|
/// # Errors
|
|
///
|
|
/// Returns `ProtocolError` if hook or permission channel is closed.
|
|
pub async fn handle_response(&self, response: ControlResponse) -> Result<()> {
|
|
match &response {
|
|
ControlResponse::Success { id, .. } | ControlResponse::Error { id, .. } => {
|
|
let mut pending_requests = self.pending_requests.lock().await;
|
|
if let Some(pending) = pending_requests.remove(id) {
|
|
let _ = pending.response_tx.send(response);
|
|
}
|
|
Ok(())
|
|
}
|
|
ControlResponse::Hook { id, event } => {
|
|
if let Some(ref tx) = self.hook_tx {
|
|
tx.send((id.clone(), *event))
|
|
.map_err(|_| ClaudeError::protocol_error("Hook channel closed"))?;
|
|
}
|
|
Ok(())
|
|
}
|
|
ControlResponse::Permission { id, request } => {
|
|
if let Some(ref tx) = self.permission_tx {
|
|
tx.send((id.clone(), request.clone()))
|
|
.map_err(|_| ClaudeError::protocol_error("Permission channel closed"))?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Create interrupt request
|
|
#[must_use]
|
|
pub fn create_interrupt_request(&self) -> ControlRequest {
|
|
ControlRequest::Interrupt { id: self.next_id() }
|
|
}
|
|
|
|
/// Create send message request
|
|
#[must_use]
|
|
pub fn create_send_message_request(&self, content: impl Into<String>) -> ControlRequest {
|
|
ControlRequest::SendMessage {
|
|
id: self.next_id(),
|
|
content: content.into(),
|
|
}
|
|
}
|
|
|
|
/// Create hook response
|
|
#[must_use]
|
|
pub fn create_hook_response(
|
|
&self,
|
|
hook_id: impl Into<String>,
|
|
response: serde_json::Value,
|
|
) -> ControlRequest {
|
|
ControlRequest::HookResponse {
|
|
id: self.next_id(),
|
|
hook_id: hook_id.into(),
|
|
response,
|
|
}
|
|
}
|
|
|
|
/// Create permission response
|
|
#[must_use]
|
|
pub fn create_permission_response(
|
|
&self,
|
|
request_id: RequestId,
|
|
result: PermissionResult,
|
|
) -> ControlRequest {
|
|
ControlRequest::PermissionResponse {
|
|
id: self.next_id(),
|
|
request_id,
|
|
result,
|
|
}
|
|
}
|
|
|
|
/// Create set model request
|
|
#[must_use]
|
|
pub fn create_set_model_request(&self, model: Option<String>) -> ControlRequest {
|
|
ControlRequest::SetModel {
|
|
id: self.next_id(),
|
|
model,
|
|
}
|
|
}
|
|
|
|
/// Create set permission mode request
|
|
#[must_use]
|
|
pub fn create_set_permission_mode_request(&self, mode: impl Into<String>) -> ControlRequest {
|
|
ControlRequest::SetPermissionMode {
|
|
id: self.next_id(),
|
|
mode: mode.into(),
|
|
}
|
|
}
|
|
|
|
/// Create set max thinking tokens request
|
|
#[must_use]
|
|
pub fn create_set_max_thinking_tokens_request(
|
|
&self,
|
|
max_thinking_tokens: Option<u32>,
|
|
) -> ControlRequest {
|
|
ControlRequest::SetMaxThinkingTokens {
|
|
id: self.next_id(),
|
|
max_thinking_tokens,
|
|
}
|
|
}
|
|
|
|
/// Create a rewind files request
|
|
///
|
|
/// Rewinds files to their state at the specified checkpoint.
|
|
/// The `user_message_uuid` comes from the `uuid` field on User messages
|
|
/// when `enable_file_checkpointing` is enabled.
|
|
#[must_use]
|
|
pub fn create_rewind_files_request(
|
|
&self,
|
|
user_message_uuid: impl Into<String>,
|
|
) -> ControlRequest {
|
|
ControlRequest::RewindFiles {
|
|
id: self.next_id(),
|
|
user_message_uuid: user_message_uuid.into(),
|
|
}
|
|
}
|
|
|
|
/// Serialize control message to JSON
|
|
///
|
|
/// # Errors
|
|
///
|
|
/// Returns `JsonEncode` if serialization fails.
|
|
pub fn serialize_message(&self, message: &ControlMessage) -> Result<String> {
|
|
serde_json::to_string(message)
|
|
.map(|s| format!("{s}\n"))
|
|
.map_err(|e| ClaudeError::json_encode(format!("Failed to serialize message: {e}")))
|
|
}
|
|
|
|
/// Deserialize control message from JSON
|
|
///
|
|
/// # Errors
|
|
///
|
|
/// Returns `JsonDecodeMsg` if parsing fails.
|
|
pub fn deserialize_message(&self, json: &str) -> Result<ControlMessage> {
|
|
serde_json::from_str(json).map_err(|e| {
|
|
ClaudeError::json_decode_msg(format!("Failed to deserialize message: {e}"))
|
|
})
|
|
}
|
|
}
|
|
|
|
impl Default for ProtocolHandler {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ToolName;

    /// Minimal permission request shared by several tests.
    fn sample_permission_request() -> PermissionRequest {
        PermissionRequest {
            tool_name: ToolName::new("test"),
            tool_input: serde_json::json!({}),
            context: crate::types::ToolPermissionContext::new(vec![]),
        }
    }

    /// "Allow, unchanged" permission result shared by several tests.
    fn allow_result() -> PermissionResult {
        crate::types::PermissionResult::Allow(crate::types::PermissionResultAllow {
            updated_input: None,
            updated_permissions: None,
        })
    }

    /// Init response advertising the given protocol version.
    fn init_response(protocol_version: &str) -> InitResponse {
        InitResponse {
            protocol_version: protocol_version.to_string(),
            cli_version: "1.0.0".to_string(),
            capabilities: ServerCapabilities {
                streaming: true,
                tools: true,
                mcp: true,
            },
            session_id: "test".to_string(),
        }
    }

    #[test]
    fn test_request_id_generation() {
        let handler = ProtocolHandler::new();
        let id1 = handler.next_id();
        let id2 = handler.next_id();
        assert_ne!(id1, id2);
        assert_eq!(id1.as_str(), "req-1");
        assert_eq!(id2.as_str(), "req-2");
    }

    #[test]
    fn test_init_request_creation() {
        let handler = ProtocolHandler::new();
        let init_req = handler.create_init_request();
        assert_eq!(init_req.protocol_version, "1.0");
        assert!(init_req.capabilities.bidirectional);
        assert!(init_req.capabilities.hooks);
        assert!(init_req.capabilities.permissions);
        assert!(init_req.capabilities.interrupts);
    }

    #[test]
    fn test_serialize_deserialize() {
        let handler = ProtocolHandler::new();
        let request = handler.create_interrupt_request();
        let message = ControlMessage::Request(request);

        let serialized = handler.serialize_message(&message).unwrap();
        let deserialized = handler.deserialize_message(serialized.trim()).unwrap();

        match deserialized {
            ControlMessage::Request(ControlRequest::Interrupt { .. }) => {}
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_deserialize_invalid_json() {
        let handler = ProtocolHandler::new();
        let result = handler.deserialize_message("not valid json");
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_invalid_message_structure() {
        let handler = ProtocolHandler::new();
        let invalid = r#"{"type":"unknown_type"}"#;
        let result = handler.deserialize_message(invalid);
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_missing_fields() {
        let handler = ProtocolHandler::new();
        let missing = r#"{"type":"request"}"#;
        let result = handler.deserialize_message(missing);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_handle_response_with_missing_pending_request() {
        let handler = ProtocolHandler::new();
        handler.set_initialized(true);

        // Create a response for a request that was never sent
        let response = ControlResponse::Success {
            id: RequestId::new("non-existent-req"),
            data: None,
        };

        // Should not error, just ignore
        let result = handler.handle_response(response).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_send_request_receives_matching_success() {
        let handler = ProtocolHandler::new();
        handler.set_initialized(true);

        let request = handler.create_interrupt_request();
        let id = ProtocolHandler::get_request_id(&request);
        let rx = handler.send_request(request).await.unwrap();

        handler
            .handle_response(ControlResponse::Success {
                id: id.clone(),
                data: None,
            })
            .await
            .unwrap();

        match rx.await.unwrap() {
            ControlResponse::Success { id: got, data } => {
                assert_eq!(got, id);
                assert!(data.is_none());
            }
            other => panic!("unexpected response: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_send_request_receives_matching_error() {
        let handler = ProtocolHandler::new();
        handler.set_initialized(true);

        let request = handler.create_set_model_request(Some("opus".to_string()));
        let id = ProtocolHandler::get_request_id(&request);
        let rx = handler.send_request(request).await.unwrap();

        handler
            .handle_response(ControlResponse::Error {
                id,
                message: "boom".to_string(),
                code: None,
            })
            .await
            .unwrap();

        match rx.await.unwrap() {
            ControlResponse::Error { message, code, .. } => {
                assert_eq!(message, "boom");
                assert!(code.is_none());
            }
            other => panic!("unexpected response: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_hook_response_without_channel() {
        let handler = ProtocolHandler::new();

        // Try to handle hook response without setting up channel
        let response = ControlResponse::Hook {
            id: "hook-1".to_string(),
            event: HookEvent::PreToolUse,
        };

        // Should not error, just no-op
        let result = handler.handle_response(response).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_hook_event_forwarded_to_channel() {
        let mut handler = ProtocolHandler::new();
        let (tx, mut rx) = mpsc::unbounded_channel();
        handler.set_hook_channel(tx);

        handler
            .handle_response(ControlResponse::Hook {
                id: "hook-1".to_string(),
                event: HookEvent::PreToolUse,
            })
            .await
            .unwrap();

        let (hook_id, event) = rx.recv().await.expect("hook should be forwarded");
        assert_eq!(hook_id, "hook-1");
        assert!(matches!(event, HookEvent::PreToolUse));
    }

    #[tokio::test]
    async fn test_hook_channel_closed_returns_error() {
        let mut handler = ProtocolHandler::new();
        let (tx, rx) = mpsc::unbounded_channel();
        handler.set_hook_channel(tx);
        drop(rx);

        let result = handler
            .handle_response(ControlResponse::Hook {
                id: "hook-1".to_string(),
                event: HookEvent::PreToolUse,
            })
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_permission_response_without_channel() {
        let handler = ProtocolHandler::new();

        // Try to handle permission response without setting up channel
        let response = ControlResponse::Permission {
            id: RequestId::new("perm-1"),
            request: sample_permission_request(),
        };

        // Should not error, just no-op
        let result = handler.handle_response(response).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_permission_request_forwarded_to_channel() {
        let mut handler = ProtocolHandler::new();
        let (tx, mut rx) = mpsc::unbounded_channel();
        handler.set_permission_channel(tx);

        handler
            .handle_response(ControlResponse::Permission {
                id: RequestId::new("perm-1"),
                request: sample_permission_request(),
            })
            .await
            .unwrap();

        let (id, _request) = rx.recv().await.expect("permission should be forwarded");
        assert_eq!(id.as_str(), "perm-1");
    }

    #[test]
    fn test_init_response_with_wrong_version() {
        let handler = ProtocolHandler::new();

        let result = handler.handle_init_response(&init_response("999.0"));
        assert!(result.is_err());
        assert!(!handler.is_initialized());
    }

    #[test]
    fn test_init_response_with_supported_version() {
        let handler = ProtocolHandler::new();

        let result = handler.handle_init_response(&init_response("1.0"));
        assert!(result.is_ok());
        assert!(handler.is_initialized());
    }

    #[tokio::test]
    async fn test_send_request_without_init() {
        let handler = ProtocolHandler::new();
        assert!(!handler.is_initialized());

        let request = handler.create_interrupt_request();
        let result = handler.send_request(request).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_serialize_all_request_types() {
        let handler = ProtocolHandler::new();

        // Every ControlRequest variant, paired with the `method` tag it must carry.
        let requests = vec![
            (handler.create_interrupt_request(), "interrupt"),
            (handler.create_send_message_request("test"), "send_message"),
            (
                handler.create_hook_response("hook-1", serde_json::json!({})),
                "hook_response",
            ),
            (
                handler.create_permission_response(RequestId::new("req-1"), allow_result()),
                "permission_response",
            ),
            (
                handler.create_set_model_request(Some("sonnet".to_string())),
                "set_model",
            ),
            (
                handler.create_set_permission_mode_request("default"),
                "set_permission_mode",
            ),
            (
                handler.create_set_max_thinking_tokens_request(Some(1024)),
                "set_max_thinking_tokens",
            ),
            (handler.create_rewind_files_request("uuid-1"), "rewind_files"),
        ];

        for (request, method) in requests {
            let json = handler
                .serialize_message(&ControlMessage::Request(request))
                .unwrap();
            assert!(json.ends_with('\n'));
            let value: serde_json::Value = serde_json::from_str(json.trim()).unwrap();
            assert_eq!(value["type"], "request");
            assert_eq!(value["method"], method);
        }
    }

    #[test]
    fn test_serialize_all_response_types() {
        let handler = ProtocolHandler::new();

        // Every ControlResponse variant, paired with the `status` tag it must carry.
        let responses = vec![
            (
                ControlResponse::Success {
                    id: RequestId::new("req-1"),
                    data: Some(serde_json::json!({"result": "ok"})),
                },
                "success",
            ),
            (
                ControlResponse::Error {
                    id: RequestId::new("req-1"),
                    message: "test error".to_string(),
                    code: Some("ERR_TEST".to_string()),
                },
                "error",
            ),
            (
                ControlResponse::Hook {
                    id: "hook-1".to_string(),
                    event: HookEvent::PreToolUse,
                },
                "hook",
            ),
            (
                ControlResponse::Permission {
                    id: RequestId::new("perm-1"),
                    request: sample_permission_request(),
                },
                "permission",
            ),
        ];

        for (response, status) in responses {
            let json = handler
                .serialize_message(&ControlMessage::Response(response))
                .unwrap();
            let value: serde_json::Value = serde_json::from_str(json.trim()).unwrap();
            assert_eq!(value["type"], "response");
            assert_eq!(value["status"], status);
        }
    }

    #[test]
    fn test_get_request_id() {
        let interrupt = ControlRequest::Interrupt {
            id: RequestId::new("id1"),
        };
        assert_eq!(ProtocolHandler::get_request_id(&interrupt).as_str(), "id1");

        let send_msg = ControlRequest::SendMessage {
            id: RequestId::new("id2"),
            content: "test".to_string(),
        };
        assert_eq!(ProtocolHandler::get_request_id(&send_msg).as_str(), "id2");

        let hook_resp = ControlRequest::HookResponse {
            id: RequestId::new("id3"),
            hook_id: "hook".to_string(),
            response: serde_json::json!({}),
        };
        assert_eq!(ProtocolHandler::get_request_id(&hook_resp).as_str(), "id3");

        let perm_resp = ControlRequest::PermissionResponse {
            id: RequestId::new("id4"),
            request_id: RequestId::new("perm"),
            result: allow_result(),
        };
        assert_eq!(ProtocolHandler::get_request_id(&perm_resp).as_str(), "id4");

        let set_model = ControlRequest::SetModel {
            id: RequestId::new("id5"),
            model: None,
        };
        assert_eq!(ProtocolHandler::get_request_id(&set_model).as_str(), "id5");

        let set_mode = ControlRequest::SetPermissionMode {
            id: RequestId::new("id6"),
            mode: "default".to_string(),
        };
        assert_eq!(ProtocolHandler::get_request_id(&set_mode).as_str(), "id6");

        let set_tokens = ControlRequest::SetMaxThinkingTokens {
            id: RequestId::new("id7"),
            max_thinking_tokens: Some(512),
        };
        assert_eq!(ProtocolHandler::get_request_id(&set_tokens).as_str(), "id7");

        let rewind = ControlRequest::RewindFiles {
            id: RequestId::new("id8"),
            user_message_uuid: "uuid".to_string(),
        };
        assert_eq!(ProtocolHandler::get_request_id(&rewind).as_str(), "id8");
    }
}
|