edge_impulse_runner/inference/model.rs

1use rand::{Rng, thread_rng};
2use std::collections::{HashMap, VecDeque};
3use std::fmt;
4use std::io::{BufRead, BufReader, Write};
5use std::os::unix::net::UnixStream;
6use std::path::Path;
7use std::process::Child;
8use std::sync::atomic::{AtomicU32, Ordering};
9use std::time::{Duration, Instant};
10use tempfile::{TempDir, tempdir};
11
12use crate::error::EimError;
13use crate::inference::messages::{
14    ClassifyMessage, ErrorResponse, HelloMessage, InferenceResponse, InferenceResult, ModelInfo,
15    SetThresholdMessage, SetThresholdResponse, ThresholdConfig,
16};
17use crate::types::{ModelParameters, ModelThreshold, SensorType, VisualAnomalyResult};
18
/// Debug callback type for receiving debug messages
///
/// Invoked with each debug line when debug mode is enabled. The `Send + Sync`
/// bounds allow the owning model to be moved across or shared between threads.
pub type DebugCallback = Box<dyn Fn(&str) + Send + Sync>;
21
22/// Edge Impulse Model Runner for Rust
23///
24/// This module provides functionality for running Edge Impulse machine learning models on Linux systems.
25/// It handles model lifecycle management, communication, and inference operations.
26///
27/// # Key Components
28///
29/// - `EimModel`: Main struct for managing Edge Impulse models
30/// - `SensorType`: Enum representing supported sensor input types
31/// - `ContinuousState`: Internal state management for continuous inference mode
32/// - `MovingAverageFilter`: Smoothing filter for continuous inference results
33///
34/// # Features
35///
36/// - Model process management and Unix socket communication
37/// - Support for both single-shot and continuous inference modes
38/// - Debug logging and callback system
39/// - Moving average filtering for continuous mode results
40/// - Automatic retry mechanisms for socket connections
41/// - Visual anomaly detection (FOMO AD) support with normalized scores
42///
43/// # Example Usage
44///
45/// ```no_run
46/// use edge_impulse_runner::{EimModel, InferenceResult};
47///
48/// // Create a new model instance
49/// let mut model = EimModel::new("path/to/model.eim").unwrap();
50///
51/// // For RGB images, features must be in 0xRRGGBB format (not normalized to [0,1])
52/// let features = vec![0xFF0000 as f32, 0x00FF00 as f32, 0x0000FF as f32]; // Example: Red, Green, Blue pixels
53/// let result = model.infer(features, None).unwrap();
54///
55/// // For visual anomaly detection models, normalize the results
56/// if let InferenceResult::VisualAnomaly { anomaly, visual_anomaly_max, visual_anomaly_mean, visual_anomaly_grid } = result.result {
57///     let (normalized_anomaly, normalized_max, normalized_mean, normalized_regions) =
58///         model.normalize_visual_anomaly(
59///             anomaly,
60///             visual_anomaly_max,
61///             visual_anomaly_mean,
62///             &visual_anomaly_grid.iter()
63///                 .map(|bbox| (bbox.value, bbox.x as u32, bbox.y as u32, bbox.width as u32, bbox.height as u32))
64///                 .collect::<Vec<_>>()
65///         );
66///     println!("Anomaly score: {:.2}%", normalized_anomaly * 100.0);
67/// }
68/// ```
69///
70/// # Communication Protocol
71///
72/// The model communicates with the Edge Impulse process using JSON messages over Unix sockets:
73/// 1. Hello message for initialization
74/// 2. Model info response
75/// 3. Classification requests
76/// 4. Inference responses
77///
78/// # Error Handling
79///
80/// The module uses a custom `EimError` type for error handling, covering:
81/// - Invalid file paths
82/// - Socket communication errors
83/// - Model execution errors
84/// - JSON serialization/deserialization errors
85///
86/// # Visual Anomaly Detection
87///
88/// For visual anomaly detection models (FOMO AD):
89/// - Scores are normalized relative to the model's minimum anomaly threshold
90/// - Results include overall anomaly score, maximum score, mean score, and anomalous regions
91/// - Region coordinates are provided in the original image dimensions
92/// - All scores are clamped to [0,1] range and displayed as percentages
93/// - Debug mode provides detailed information about thresholds and regions
94///
95/// # Image Feature Format
96///
97/// For RGB images, features must be in 0xRRGGBB format (not normalized to [0,1]).
98/// For grayscale images, features must be in 0xGGGGGG format (repeating the grayscale value).
99///
100/// # Threshold Configuration
101///
102/// Models can be configured with different thresholds:
103/// - Anomaly detection: `min_anomaly_score` threshold for visual anomaly detection
104/// - Object detection: `min_score` threshold for object confidence
105/// - Object tracking: `keep_grace`, `max_observations`, and `threshold` parameters
106///
107/// Thresholds can be updated at runtime using `set_learn_block_threshold`.
pub struct EimModel {
    /// Path to the Edge Impulse model file (.eim)
    path: std::path::PathBuf,
    /// Path to the Unix socket used for IPC
    socket_path: std::path::PathBuf,
    /// Handle to the temporary directory for the socket (ensures cleanup)
    ///
    /// This field is required to keep the temporary directory alive for the lifetime of the model.
    /// If dropped, the directory (and socket file) would be deleted, breaking model communication.
    #[allow(dead_code)]
    tempdir: Option<TempDir>,
    /// Active Unix socket connection to the model process
    socket: UnixStream,
    /// Enable debug logging of socket communications
    debug: bool,
    /// Optional debug callback for receiving debug messages
    debug_callback: Option<DebugCallback>,
    /// Handle to the model process (kept alive while model exists)
    _process: Child,
    /// Cached model information received during initialization
    model_info: Option<ModelInfo>,
    /// Atomic counter for generating unique message IDs
    message_id: AtomicU32,
    /// Optional child process handle for restart functionality
    #[allow(dead_code)]
    child: Option<Child>,
    /// Sliding-window state for continuous inference mode; lazily created on
    /// the first continuous `infer` call
    continuous_state: Option<ContinuousState>,
    /// Local model parameters, initialized to defaults in the constructor.
    /// NOTE(review): `parameters()` reads from `model_info` instead, and no
    /// code visible in this file updates this field — confirm it is still needed.
    model_parameters: ModelParameters,
}
137
/// Internal state for continuous inference mode: a sliding window of features
/// plus per-label moving-average filters for smoothing classification scores.
#[derive(Debug)]
struct ContinuousState {
    /// Sliding window of the most recent input features (trimmed to at most `slice_size`)
    feature_matrix: Vec<f32>,
    /// True once `feature_matrix` has been filled to `slice_size` at least once
    feature_buffer_full: bool,
    /// One moving-average filter per classification label, keyed by label name
    maf_buffers: HashMap<String, MovingAverageFilter>,
    /// Number of features the model expects per inference (window capacity)
    slice_size: usize,
}
145
146impl ContinuousState {
147    fn new(labels: Vec<String>, slice_size: usize) -> Self {
148        Self {
149            feature_matrix: Vec::new(),
150            feature_buffer_full: false,
151            maf_buffers: labels
152                .into_iter()
153                .map(|label| (label, MovingAverageFilter::new(4)))
154                .collect(),
155            slice_size,
156        }
157    }
158
159    fn update_features(&mut self, features: &[f32]) {
160        // Add new features to the matrix
161        self.feature_matrix.extend_from_slice(features);
162
163        // Check if buffer is full
164        if self.feature_matrix.len() >= self.slice_size {
165            self.feature_buffer_full = true;
166            // Keep only the most recent features if we've exceeded the buffer size
167            if self.feature_matrix.len() > self.slice_size {
168                self.feature_matrix
169                    .drain(0..self.feature_matrix.len() - self.slice_size);
170            }
171        }
172    }
173
174    fn apply_maf(&mut self, classification: &mut HashMap<String, f32>) {
175        for (label, value) in classification.iter_mut() {
176            if let Some(maf) = self.maf_buffers.get_mut(label) {
177                *value = maf.update(*value);
178            }
179        }
180    }
181}
182
/// Fixed-width moving-average filter used to smooth continuous-mode scores.
#[derive(Debug)]
struct MovingAverageFilter {
    /// Most recent samples, oldest at the front (never longer than `window_size`)
    buffer: VecDeque<f32>,
    /// Maximum number of samples kept in the window
    window_size: usize,
    /// Running sum of the samples currently held in `buffer`
    sum: f32,
}

impl MovingAverageFilter {
    /// Create an empty filter averaging over the last `window_size` samples.
    fn new(window_size: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(window_size),
            window_size,
            sum: 0.0,
        }
    }

    /// Push a new sample and return the mean of the samples currently held.
    ///
    /// Until the window fills, the mean is taken over however many samples
    /// have been seen so far.
    fn update(&mut self, value: f32) -> f32 {
        // Evict the oldest sample once the window is at capacity.
        if self.buffer.len() >= self.window_size {
            let oldest = self
                .buffer
                .pop_front()
                .expect("window at capacity implies a non-empty buffer");
            self.sum -= oldest;
        }
        self.buffer.push_back(value);
        self.sum += value;
        self.sum / self.buffer.len() as f32
    }
}
208
// Manual `Debug` impl: `debug_callback` holds a trait object with no `Debug`
// bound, so `#[derive(Debug)]` is unavailable; that field is intentionally
// omitted from the output.
impl fmt::Debug for EimModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("EimModel")
            .field("path", &self.path)
            .field("socket_path", &self.socket_path)
            .field("socket", &self.socket)
            .field("debug", &self.debug)
            .field("_process", &self._process)
            .field("model_info", &self.model_info)
            .field("message_id", &self.message_id)
            .field("child", &self.child)
            // Skip debug_callback field as it doesn't implement Debug
            .field("continuous_state", &self.continuous_state)
            .field("model_parameters", &self.model_parameters)
            .finish()
    }
}
226
impl EimModel {
    /// Creates a new EimModel instance from a path to the .eim file.
    ///
    /// This is the standard way to create a new model instance. The function will:
    /// 1. Validate the file extension
    /// 2. Spawn the model process
    /// 3. Establish socket communication
    /// 4. Initialize the model
    ///
    /// # Arguments
    ///
    /// * `path` - Path to the .eim file. Must be a valid Edge Impulse model file.
    ///
    /// # Returns
    ///
    /// Returns `Result<EimModel, EimError>` where:
    /// - `Ok(EimModel)` - Successfully created and initialized model
    /// - `Err(EimError)` - Failed to create model (invalid path, process spawn failure, etc.)
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use edge_impulse_runner::EimModel;
    ///
    /// let model = EimModel::new("path/to/model.eim").unwrap();
    /// ```
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, EimError> {
        Self::new_with_debug(path, false)
    }

    /// Creates a new EimModel instance with a specific Unix socket path.
    ///
    /// Similar to `new()`, but allows specifying the socket path for communication.
    /// This is useful when you need control over the socket location or when running
    /// multiple models simultaneously.
    ///
    /// # Arguments
    ///
    /// * `path` - Path to the .eim file
    /// * `socket_path` - Custom path where the Unix socket should be created
    pub fn new_with_socket<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        socket_path: S,
    ) -> Result<Self, EimError> {
        Self::new_with_socket_and_debug_internal(path, socket_path, false, None)
    }

    /// Create a new EimModel instance with debug output enabled
    ///
    /// A unique temporary directory is created to hold the Unix socket; its
    /// handle is stored on the model so the directory (and socket file) lives
    /// as long as the connection does.
    pub fn new_with_debug<P: AsRef<Path>>(path: P, debug: bool) -> Result<Self, EimError> {
        // Create a unique temporary directory for the socket
        let tempdir = tempdir()
            .map_err(|e| EimError::ExecutionError(format!("Failed to create tempdir: {e}")))?;
        // Generate a random socket filename.
        // `gen` is a reserved keyword in the Rust 2024 edition, hence `r#gen`.
        let mut rng = thread_rng();
        let socket_name = format!("eim_socket_{}", rng.r#gen::<u64>());
        let socket_path = tempdir.path().join(socket_name);
        Self::new_with_socket_and_debug_internal(path, &socket_path, debug, Some(tempdir))
    }

    /// Ensure the model file has execution permissions for the current user
    ///
    /// Adds the user-executable bit (0o100) only when it is missing; all other
    /// permission bits are left untouched.
    fn ensure_executable<P: AsRef<Path>>(path: P) -> Result<(), EimError> {
        use std::os::unix::fs::PermissionsExt;

        let path = path.as_ref();
        let metadata = std::fs::metadata(path)
            .map_err(|e| EimError::ExecutionError(format!("Failed to get file metadata: {e}")))?;

        let perms = metadata.permissions();
        let current_mode = perms.mode();
        if current_mode & 0o100 == 0 {
            // File is not executable for user, try to make it executable
            let mut new_perms = perms;
            new_perms.set_mode(current_mode | 0o100); // Add executable bit for user only
            std::fs::set_permissions(path, new_perms).map_err(|e| {
                EimError::ExecutionError(format!("Failed to set executable permissions: {e}"))
            })?;
        }
        Ok(())
    }

    /// Create a new EimModel instance with debug output enabled and a specific socket path
    pub fn new_with_socket_and_debug<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        socket_path: S,
        debug: bool,
    ) -> Result<Self, EimError> {
        Self::new_with_socket_and_debug_internal(path, socket_path, debug, None)
    }

    /// Shared constructor backing all public `new*` variants.
    ///
    /// Validates the `.eim` extension, resolves a relative path against the
    /// current directory, ensures the binary is executable, spawns it with the
    /// socket path as its only argument, connects with retry, and performs the
    /// hello handshake. `tempdir`, when provided, is stored on the model so
    /// the socket's directory outlives the connection.
    fn new_with_socket_and_debug_internal<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        socket_path: S,
        debug: bool,
        tempdir: Option<TempDir>,
    ) -> Result<Self, EimError> {
        let path = path.as_ref();
        let socket_path = socket_path.as_ref();

        // Validate file extension
        if path.extension().and_then(|s| s.to_str()) != Some("eim") {
            return Err(EimError::InvalidPath);
        }

        // Convert relative path to absolute path
        let absolute_path = if path.is_absolute() {
            path.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|_e| EimError::InvalidPath)?
                .join(path)
        };

        // Ensure the model file is executable
        Self::ensure_executable(&absolute_path)?;

        // Start the process
        let process = std::process::Command::new(&absolute_path)
            .arg(socket_path)
            .spawn()
            .map_err(|e| EimError::ExecutionError(e.to_string()))?;

        let socket = Self::connect_with_retry(socket_path, Duration::from_secs(5))?;

        let mut model = Self {
            path: absolute_path, // Store the absolute path
            socket_path: socket_path.to_path_buf(),
            tempdir,
            socket,
            debug,
            _process: process,
            model_info: None,
            message_id: AtomicU32::new(1),
            child: None,
            debug_callback: None,
            continuous_state: None,
            model_parameters: ModelParameters::default(),
        };

        // Initialize the model by sending hello message
        model.send_hello()?;

        Ok(model)
    }

    /// Attempts to connect to the Unix socket with a retry mechanism
    ///
    /// This function will repeatedly try to connect to the socket until either:
    /// - A successful connection is established
    /// - An unexpected error occurs
    /// - The timeout duration is exceeded
    ///
    /// # Arguments
    ///
    /// * `socket_path` - Path to the Unix socket
    /// * `timeout` - Maximum time to wait for connection
    fn connect_with_retry(socket_path: &Path, timeout: Duration) -> Result<UnixStream, EimError> {
        let start = Instant::now();
        let retry_interval = Duration::from_millis(50);

        while start.elapsed() < timeout {
            match UnixStream::connect(socket_path) {
                Ok(stream) => return Ok(stream),
                Err(e) => {
                    // NotFound and ConnectionRefused are expected errors while the socket
                    // is being created, so we retry in these cases
                    if e.kind() != std::io::ErrorKind::NotFound
                        && e.kind() != std::io::ErrorKind::ConnectionRefused
                    {
                        return Err(EimError::SocketError(format!(
                            "Failed to connect to socket: {e}"
                        )));
                    }
                }
            }
            std::thread::sleep(retry_interval);
        }

        Err(EimError::SocketError(format!(
            "Timeout waiting for socket {} to become available",
            socket_path.display()
        )))
    }

    /// Get the next message ID
    ///
    /// Relaxed ordering is sufficient: IDs only need to be unique, not ordered
    /// with respect to other memory operations.
    fn next_message_id(&self) -> u32 {
        self.message_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Set a debug callback function to receive debug messages
    ///
    /// When debug mode is enabled, this callback will be invoked with debug messages
    /// from the model runner. This is useful for logging or displaying debug information
    /// in your application.
    ///
    /// # Arguments
    ///
    /// * `callback` - Function that takes a string slice and handles the debug message
    pub fn set_debug_callback<F>(&mut self, callback: F)
    where
        F: Fn(&str) + Send + Sync + 'static,
    {
        self.debug_callback = Some(Box::new(callback));
    }

    /// Send debug messages when debug mode is enabled
    ///
    /// Prints to stdout and, if one is registered, forwards the message to the
    /// debug callback. Both are skipped entirely when `debug` is false.
    fn debug_message(&self, message: &str) {
        if self.debug {
            println!("{message}");
            if let Some(callback) = &self.debug_callback {
                callback(message);
            }
        }
    }

    /// Perform the hello handshake with the model process.
    ///
    /// Sends a `HelloMessage` over the socket and blocks on a single response
    /// line. On success the parsed `ModelInfo` is cached in `self.model_info`;
    /// if the line does not parse as `ModelInfo`, it is re-parsed as an
    /// `ErrorResponse` to surface the server-reported error.
    fn send_hello(&mut self) -> Result<(), EimError> {
        let hello_msg = HelloMessage {
            hello: 1,
            id: self.next_message_id(),
        };

        let msg = serde_json::to_string(&hello_msg)?;
        self.debug_message(&format!("Sending hello message: {msg}"));

        writeln!(self.socket, "{msg}").map_err(|e| {
            self.debug_message(&format!("Failed to send hello: {e}"));
            EimError::SocketError(format!("Failed to send hello message: {e}"))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush hello: {e}"));
            EimError::SocketError(format!("Failed to flush socket: {e}"))
        })?;

        self.debug_message("Waiting for hello response...");

        let mut reader = BufReader::new(&self.socket);
        let mut line = String::new();

        match reader.read_line(&mut line) {
            Ok(n) => {
                self.debug_message(&format!("Read {n} bytes: {line}"));

                match serde_json::from_str::<ModelInfo>(&line) {
                    Ok(info) => {
                        self.debug_message("Successfully parsed model info");
                        if !info.success {
                            self.debug_message("Model initialization failed");
                            return Err(EimError::ExecutionError(
                                "Model initialization failed".to_string(),
                            ));
                        }
                        self.debug_message("Got model info response, storing it");
                        self.model_info = Some(info);
                        return Ok(());
                    }
                    Err(e) => {
                        self.debug_message(&format!("Failed to parse model info: {e}"));
                        // Fall back to interpreting the line as an explicit error reply
                        if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line)
                            && !error.success
                        {
                            self.debug_message(&format!("Got error response: {error:?}"));
                            return Err(EimError::ExecutionError(
                                error.error.unwrap_or_else(|| "Unknown error".to_string()),
                            ));
                        }
                    }
                }
            }
            Err(e) => {
                self.debug_message(&format!("Failed to read hello response: {e}"));
                return Err(EimError::SocketError(format!(
                    "Failed to read response: {e}"
                )));
            }
        }

        self.debug_message("No valid hello response received");
        Err(EimError::SocketError(
            "No valid response received".to_string(),
        ))
    }

    /// Get the path to the EIM file
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Get the socket path used for communication
    pub fn socket_path(&self) -> &Path {
        &self.socket_path
    }

    /// Get the sensor type for this model
    pub fn sensor_type(&self) -> Result<SensorType, EimError> {
        self.model_info
            .as_ref()
            .map(|info| SensorType::from(info.model_parameters.sensor))
            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
    }

    /// Get the model parameters
    ///
    /// Note: reads from the cached `model_info` handshake response, not from
    /// the `model_parameters` field.
    pub fn parameters(&self) -> Result<&ModelParameters, EimError> {
        self.model_info
            .as_ref()
            .map(|info| &info.model_parameters)
            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
    }

    /// Run inference on the input features
    ///
    /// This method automatically handles both continuous and non-continuous modes:
    ///
    /// ## Non-Continuous Mode
    /// - Each call is independent
    /// - All features must be provided in a single call
    /// - Results are returned immediately
    ///
    /// ## Continuous Mode (automatically enabled for supported models)
    /// - Features are accumulated across calls
    /// - Internal buffer maintains sliding window of features
    /// - Moving average filter smooths results
    /// - Initial calls may return empty results while buffer fills
    ///
    /// # Arguments
    ///
    /// * `features` - Vector of input features
    /// * `debug` - Optional debug flag to enable detailed output for this inference
    ///
    /// # Returns
    ///
    /// Returns `Result<InferenceResponse, EimError>` containing inference results
    pub fn infer(
        &mut self,
        features: Vec<f32>,
        debug: Option<bool>,
    ) -> Result<InferenceResponse, EimError> {
        // Initialize model info if needed
        if self.model_info.is_none() {
            self.send_hello()?;
        }

        let uses_continuous_mode = self.requires_continuous_mode();

        if uses_continuous_mode {
            self.infer_continuous_internal(features, debug)
        } else {
            self.infer_single(features, debug)
        }
    }

    /// Continuous-mode inference: accumulate features in a sliding window and
    /// classify once the window is full, smoothing classification scores with
    /// per-label moving-average filters.
    ///
    /// Returns an empty classification result while the feature buffer is
    /// still filling up.
    fn infer_continuous_internal(
        &mut self,
        features: Vec<f32>,
        debug: Option<bool>,
    ) -> Result<InferenceResponse, EimError> {
        // Initialize continuous state if needed
        if self.continuous_state.is_none() {
            let labels = self
                .model_info
                .as_ref()
                .map(|info| info.model_parameters.labels.clone())
                .unwrap_or_default();
            let slice_size = self.input_size()?;

            self.continuous_state = Some(ContinuousState::new(labels, slice_size));
        }

        // Take ownership of state temporarily to avoid multiple mutable borrows
        let mut state = self.continuous_state.take().unwrap();
        state.update_features(&features);

        let response = if !state.feature_buffer_full {
            // Return empty response while building up the buffer
            Ok(InferenceResponse {
                success: true,
                id: self.next_message_id(),
                result: InferenceResult::Classification {
                    classification: HashMap::new(),
                },
            })
        } else {
            // Run inference on the full buffer
            let mut response = self.infer_single(state.feature_matrix.clone(), debug)?;

            // Apply moving average filter to the results
            if let InferenceResult::Classification {
                ref mut classification,
            } = response.result
            {
                state.apply_maf(classification);
            }

            Ok(response)
        };

        // Restore the state (it was taken above; early `?` returns skip this,
        // so a failed inference resets continuous state to None)
        self.continuous_state = Some(state);

        response
    }

    /// Single-shot inference over the socket.
    ///
    /// Sends a classify message, then polls for a response with the socket in
    /// non-blocking mode: 5 second overall timeout, 10 ms sleep between empty
    /// reads. Non-matching lines are discarded until a successful
    /// `InferenceResponse` arrives. The socket is restored to blocking mode on
    /// every exit path.
    fn infer_single(
        &mut self,
        features: Vec<f32>,
        debug: Option<bool>,
    ) -> Result<InferenceResponse, EimError> {
        // First ensure we've sent the hello message and received model info
        if self.model_info.is_none() {
            self.debug_message("No model info, sending hello message...");
            self.send_hello()?;
            self.debug_message("Hello handshake completed");
        }

        let msg = ClassifyMessage {
            classify: features.clone(),
            id: self.next_message_id(),
            debug,
        };

        let msg_str = serde_json::to_string(&msg)?;
        self.debug_message(&format!(
            "Sending inference message with {} features",
            features.len()
        ));

        writeln!(self.socket, "{msg_str}").map_err(|e| {
            self.debug_message(&format!("Failed to send inference message: {e}"));
            EimError::SocketError(format!("Failed to send inference message: {e}"))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush inference message: {e}"));
            EimError::SocketError(format!("Failed to flush socket: {e}"))
        })?;

        self.debug_message("Inference message sent, waiting for response...");

        // Set socket to non-blocking mode
        self.socket.set_nonblocking(true).map_err(|e| {
            self.debug_message(&format!("Failed to set non-blocking mode: {e}"));
            EimError::SocketError(format!("Failed to set non-blocking mode: {e}"))
        })?;

        let mut reader = BufReader::new(&self.socket);
        let mut buffer = String::new();
        let start = Instant::now();
        let timeout = Duration::from_secs(5);

        while start.elapsed() < timeout {
            match reader.read_line(&mut buffer) {
                Ok(0) => {
                    // NOTE(review): EOF falls through to the generic timeout
                    // error below, which reports a timeout even though the
                    // peer closed the connection.
                    self.debug_message("EOF reached");
                    break;
                }
                Ok(n) => {
                    // Skip printing feature values in the response
                    if !buffer.contains("features:") && !buffer.contains("Features (") {
                        self.debug_message(&format!("Read {n} bytes: {buffer}"));
                    }

                    if let Ok(response) = serde_json::from_str::<InferenceResponse>(&buffer)
                        && response.success
                    {
                        self.debug_message("Got successful inference response");
                        // Reset to blocking mode before returning
                        let _ = self.socket.set_nonblocking(false);
                        return Ok(response);
                    }
                    buffer.clear();
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    // No data available yet, sleep briefly and retry
                    std::thread::sleep(Duration::from_millis(10));
                    continue;
                }
                Err(e) => {
                    self.debug_message(&format!("Read error: {e}"));
                    // Always try to reset blocking mode, even on error
                    let _ = self.socket.set_nonblocking(false);
                    return Err(EimError::SocketError(format!("Read error: {e}")));
                }
            }
        }

        // Reset to blocking mode before returning
        let _ = self.socket.set_nonblocking(false);
        self.debug_message("Timeout reached");

        Err(EimError::ExecutionError(format!(
            "No valid response received within {} seconds",
            timeout.as_secs()
        )))
    }

    /// Check if model requires continuous mode
    ///
    /// Defaults to false when model info has not been fetched yet.
    fn requires_continuous_mode(&self) -> bool {
        self.model_info
            .as_ref()
            .map(|info| info.model_parameters.use_continuous_mode)
            .unwrap_or(false)
    }

    /// Get the required number of input features for this model
    ///
    /// Returns the number of features expected by the model for each classification.
    /// This is useful for:
    /// - Validating input size before classification
    /// - Preparing the correct amount of data
    /// - Padding or truncating inputs to match model requirements
    ///
    /// # Returns
    ///
    /// The number of input features required by the model
    pub fn input_size(&self) -> Result<usize, EimError> {
        self.model_info
            .as_ref()
            .map(|info| info.model_parameters.input_features_count as usize)
            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
    }

    /// Set a threshold for a specific learning block
    ///
    /// This method allows updating thresholds for different types of blocks:
    /// - Anomaly detection (GMM)
    /// - Object detection
    /// - Object tracking
    ///
    /// # Arguments
    ///
    /// * `threshold` - The threshold configuration to set
    ///
    /// # Returns
    ///
    /// Returns `Result<(), EimError>` indicating success or failure
    ///
    /// NOTE(review): this method is `async` but performs blocking socket I/O
    /// (`writeln!` / `read_line`), so it will block the executor thread it
    /// runs on — confirm whether it should be made synchronous or use async I/O.
    pub async fn set_learn_block_threshold(
        &mut self,
        threshold: ThresholdConfig,
    ) -> Result<(), EimError> {
        // First check if model info is available and supports thresholds
        if self.model_info.is_none() {
            self.debug_message("No model info available, sending hello message...");
            self.send_hello()?;
        }

        // Log the current model state
        if let Some(info) = &self.model_info {
            self.debug_message(&format!(
                "Current model type: {}",
                info.model_parameters.model_type
            ));
            self.debug_message(&format!(
                "Current model parameters: {:?}",
                info.model_parameters
            ));
        }

        let msg = SetThresholdMessage {
            set_threshold: threshold,
            id: self.next_message_id(),
        };

        let msg_str = serde_json::to_string(&msg)?;
        self.debug_message(&format!("Sending threshold message: {msg_str}"));

        writeln!(self.socket, "{msg_str}").map_err(|e| {
            self.debug_message(&format!("Failed to send threshold message: {e}"));
            EimError::SocketError(format!("Failed to send threshold message: {e}"))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush threshold message: {e}"));
            EimError::SocketError(format!("Failed to flush socket: {e}"))
        })?;

        let mut reader = BufReader::new(&self.socket);
        let mut line = String::new();

        match reader.read_line(&mut line) {
            Ok(_) => {
                self.debug_message(&format!("Received response: {line}"));
                match serde_json::from_str::<SetThresholdResponse>(&line) {
                    Ok(response) => {
                        if response.success {
                            self.debug_message("Successfully set threshold");
                            Ok(())
                        } else {
                            self.debug_message("Server reported failure setting threshold");
                            Err(EimError::ExecutionError(
                                "Server reported failure setting threshold".to_string(),
                            ))
                        }
                    }
                    Err(e) => {
                        self.debug_message(&format!("Failed to parse threshold response: {e}"));
                        // Try to parse as error response
                        if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line) {
                            Err(EimError::ExecutionError(
                                error.error.unwrap_or_else(|| "Unknown error".to_string()),
                            ))
                        } else {
                            Err(EimError::ExecutionError(format!(
                                "Invalid threshold response format: {e}"
                            )))
                        }
                    }
                }
            }
            Err(e) => {
                self.debug_message(&format!("Failed to read threshold response: {e}"));
                Err(EimError::SocketError(format!(
                    "Failed to read response: {e}"
                )))
            }
        }
    }

    /// Get the minimum anomaly score threshold from model parameters
    ///
    /// Falls back to 6.0 when no `AnomalyGMM` threshold is present or model
    /// info has not been fetched yet.
    fn get_min_anomaly_score(&self) -> f32 {
        self.model_info
            .as_ref()
            .and_then(|info| {
                info.model_parameters
                    .thresholds
                    .iter()
                    .find_map(|t| match t {
                        ModelThreshold::AnomalyGMM {
                            min_anomaly_score, ..
                        } => Some(*min_anomaly_score),
                        _ => None,
                    })
            })
            .unwrap_or(6.0)
    }

    /// Normalize an anomaly score relative to the model's minimum threshold
    ///
    /// Divides by the threshold and caps the result at 1.0.
    fn normalize_anomaly_score(&self, score: f32) -> f32 {
        (score / self.get_min_anomaly_score()).min(1.0)
    }

    /// Normalize a visual anomaly result
    ///
    /// Applies `normalize_anomaly_score` to the overall, max and mean scores
    /// and to each region's score; region coordinates pass through unchanged.
    pub fn normalize_visual_anomaly(
        &self,
        anomaly: f32,
        max: f32,
        mean: f32,
        regions: &[(f32, u32, u32, u32, u32)],
    ) -> VisualAnomalyResult {
        let normalized_anomaly = self.normalize_anomaly_score(anomaly);
        let normalized_max = self.normalize_anomaly_score(max);
        let normalized_mean = self.normalize_anomaly_score(mean);
        let normalized_regions: Vec<_> = regions
            .iter()
            .map(|(value, x, y, w, h)| (self.normalize_anomaly_score(*value), *x, *y, *w, *h))
            .collect();

        (
            normalized_anomaly,
            normalized_max,
            normalized_mean,
            normalized_regions,
        )
    }
}