edge_impulse_runner/inference/
model.rs

1use std::collections::{HashMap, VecDeque};
2use std::fmt;
3use std::io::{BufRead, BufReader, Write};
4use std::os::unix::net::UnixStream;
5use std::path::Path;
6use std::process::Child;
7use std::sync::atomic::{AtomicU32, Ordering};
8use std::time::{Duration, Instant};
9
10use crate::error::EimError;
11use crate::inference::messages::{
12    ClassifyMessage, ErrorResponse, HelloMessage, InferenceResponse, InferenceResult, ModelInfo,
13    SetThresholdMessage, SetThresholdResponse, ThresholdConfig,
14};
15use crate::types::{ModelParameters, ModelThreshold, SensorType, VisualAnomalyResult};
16
/// Debug callback type for receiving debug messages.
///
/// Invoked by `debug_message` with each formatted debug string when debug
/// mode is enabled. Must be `Send + Sync` so the owning model can cross
/// thread boundaries.
pub type DebugCallback = Box<dyn Fn(&str) + Send + Sync>;
19
20/// Edge Impulse Model Runner for Rust
21///
22/// This module provides functionality for running Edge Impulse machine learning models on Linux systems.
23/// It handles model lifecycle management, communication, and inference operations.
24///
25/// # Key Components
26///
27/// - `EimModel`: Main struct for managing Edge Impulse models
28/// - `SensorType`: Enum representing supported sensor input types
29/// - `ContinuousState`: Internal state management for continuous inference mode
30/// - `MovingAverageFilter`: Smoothing filter for continuous inference results
31///
32/// # Features
33///
34/// - Model process management and Unix socket communication
35/// - Support for both single-shot and continuous inference modes
36/// - Debug logging and callback system
37/// - Moving average filtering for continuous mode results
38/// - Automatic retry mechanisms for socket connections
39/// - Visual anomaly detection (FOMO AD) support with normalized scores
40///
41/// # Example Usage
42///
43/// ```no_run
44/// use edge_impulse_runner::{EimModel, InferenceResult};
45///
46/// // Create a new model instance
47/// let mut model = EimModel::new("path/to/model.eim").unwrap();
48///
49/// // For RGB images, features must be in 0xRRGGBB format (not normalized to [0,1])
50/// let features = vec![0xFF0000 as f32, 0x00FF00 as f32, 0x0000FF as f32]; // Example: Red, Green, Blue pixels
51/// let result = model.infer(features, None).unwrap();
52///
53/// // For visual anomaly detection models, normalize the results
54/// if let InferenceResult::VisualAnomaly { anomaly, visual_anomaly_max, visual_anomaly_mean, visual_anomaly_grid } = result.result {
55///     let (normalized_anomaly, normalized_max, normalized_mean, normalized_regions) =
56///         model.normalize_visual_anomaly(
57///             anomaly,
58///             visual_anomaly_max,
59///             visual_anomaly_mean,
60///             &visual_anomaly_grid.iter()
61///                 .map(|bbox| (bbox.value, bbox.x as u32, bbox.y as u32, bbox.width as u32, bbox.height as u32))
62///                 .collect::<Vec<_>>()
63///         );
64///     println!("Anomaly score: {:.2}%", normalized_anomaly * 100.0);
65/// }
66/// ```
67///
68/// # Communication Protocol
69///
70/// The model communicates with the Edge Impulse process using JSON messages over Unix sockets:
71/// 1. Hello message for initialization
72/// 2. Model info response
73/// 3. Classification requests
74/// 4. Inference responses
75///
76/// # Error Handling
77///
78/// The module uses a custom `EimError` type for error handling, covering:
79/// - Invalid file paths
80/// - Socket communication errors
81/// - Model execution errors
82/// - JSON serialization/deserialization errors
83///
84/// # Visual Anomaly Detection
85///
86/// For visual anomaly detection models (FOMO AD):
87/// - Scores are normalized relative to the model's minimum anomaly threshold
88/// - Results include overall anomaly score, maximum score, mean score, and anomalous regions
89/// - Region coordinates are provided in the original image dimensions
90/// - All scores are clamped to [0,1] range and displayed as percentages
91/// - Debug mode provides detailed information about thresholds and regions
92///
93/// # Image Feature Format
94///
95/// For RGB images, features must be in 0xRRGGBB format (not normalized to [0,1]).
96/// For grayscale images, features must be in 0xGGGGGG format (repeating the grayscale value).
97///
98/// # Threshold Configuration
99///
100/// Models can be configured with different thresholds:
101/// - Anomaly detection: `min_anomaly_score` threshold for visual anomaly detection
102/// - Object detection: `min_score` threshold for object confidence
103/// - Object tracking: `keep_grace`, `max_observations`, and `threshold` parameters
104///
105/// Thresholds can be updated at runtime using `set_learn_block_threshold`.
pub struct EimModel {
    /// Path to the Edge Impulse model file (.eim); always stored absolute
    path: std::path::PathBuf,
    /// Path to the Unix socket used for IPC with the model process
    socket_path: std::path::PathBuf,
    /// Active Unix socket connection to the model process
    socket: UnixStream,
    /// Enable debug logging of socket communications
    debug: bool,
    /// Optional debug callback for receiving debug messages
    /// (omitted from the manual `Debug` impl because closures are not `Debug`)
    debug_callback: Option<DebugCallback>,
    /// Handle to the model process (kept alive while model exists)
    _process: Child,
    /// Cached model information received during the hello handshake
    model_info: Option<ModelInfo>,
    /// Atomic counter for generating unique message IDs
    message_id: AtomicU32,
    /// Optional child process handle for restart functionality
    // NOTE(review): never assigned anywhere in this file (always `None`) —
    // presumably reserved for a restart feature; confirm before removing.
    #[allow(dead_code)]
    child: Option<Child>,
    /// Sliding-window state, lazily created when the model runs in continuous mode
    continuous_state: Option<ContinuousState>,
    // NOTE(review): only ever holds `ModelParameters::default()` in this file;
    // live parameters are read from `model_info` instead — verify intent.
    model_parameters: ModelParameters,
}
129
/// Internal state for continuous-mode inference: a sliding window of features
/// plus one moving-average filter per classification label.
#[derive(Debug)]
struct ContinuousState {
    /// Sliding window holding at most `slice_size` of the most recent features
    feature_matrix: Vec<f32>,
    /// Set once the window has been completely filled at least once
    feature_buffer_full: bool,
    /// Per-label moving-average filters used to smooth classification scores
    maf_buffers: HashMap<String, MovingAverageFilter>,
    /// Required number of features per inference (the model's input size)
    slice_size: usize,
}
137
138impl ContinuousState {
139    fn new(labels: Vec<String>, slice_size: usize) -> Self {
140        Self {
141            feature_matrix: Vec::new(),
142            feature_buffer_full: false,
143            maf_buffers: labels
144                .into_iter()
145                .map(|label| (label, MovingAverageFilter::new(4)))
146                .collect(),
147            slice_size,
148        }
149    }
150
151    fn update_features(&mut self, features: &[f32]) {
152        // Add new features to the matrix
153        self.feature_matrix.extend_from_slice(features);
154
155        // Check if buffer is full
156        if self.feature_matrix.len() >= self.slice_size {
157            self.feature_buffer_full = true;
158            // Keep only the most recent features if we've exceeded the buffer size
159            if self.feature_matrix.len() > self.slice_size {
160                self.feature_matrix
161                    .drain(0..self.feature_matrix.len() - self.slice_size);
162            }
163        }
164    }
165
166    fn apply_maf(&mut self, classification: &mut HashMap<String, f32>) {
167        for (label, value) in classification.iter_mut() {
168            if let Some(maf) = self.maf_buffers.get_mut(label) {
169                *value = maf.update(*value);
170            }
171        }
172    }
173}
174
/// Fixed-window moving average over a stream of `f32` samples.
#[derive(Debug)]
struct MovingAverageFilter {
    /// Samples currently inside the window, oldest at the front
    buffer: VecDeque<f32>,
    /// Maximum number of samples averaged at once
    window_size: usize,
    /// Running sum of everything in `buffer`
    sum: f32,
}

impl MovingAverageFilter {
    /// Create a filter that averages over the last `window_size` samples.
    fn new(window_size: usize) -> Self {
        MovingAverageFilter {
            buffer: VecDeque::with_capacity(window_size),
            window_size,
            sum: 0.0,
        }
    }

    /// Push a sample, evicting the oldest when the window is full, and return
    /// the average of the samples currently held.
    fn update(&mut self, value: f32) -> f32 {
        while self.buffer.len() >= self.window_size {
            if let Some(oldest) = self.buffer.pop_front() {
                self.sum -= oldest;
            }
        }
        self.buffer.push_back(value);
        self.sum += value;
        self.sum / self.buffer.len() as f32
    }
}
200
201impl fmt::Debug for EimModel {
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        f.debug_struct("EimModel")
204            .field("path", &self.path)
205            .field("socket_path", &self.socket_path)
206            .field("socket", &self.socket)
207            .field("debug", &self.debug)
208            .field("_process", &self._process)
209            .field("model_info", &self.model_info)
210            .field("message_id", &self.message_id)
211            .field("child", &self.child)
212            // Skip debug_callback field as it doesn't implement Debug
213            .field("continuous_state", &self.continuous_state)
214            .field("model_parameters", &self.model_parameters)
215            .finish()
216    }
217}
218
219impl EimModel {
    /// Creates a new EimModel instance from a path to the .eim file.
    ///
    /// This is the standard way to create a new model instance, equivalent to
    /// [`EimModel::new_with_debug`] with debug output disabled. The function will:
    /// 1. Validate the file extension
    /// 2. Spawn the model process
    /// 3. Establish socket communication
    /// 4. Initialize the model (hello handshake)
    ///
    /// # Arguments
    ///
    /// * `path` - Path to the .eim file. Must be a valid Edge Impulse model file.
    ///
    /// # Returns
    ///
    /// Returns `Result<EimModel, EimError>` where:
    /// - `Ok(EimModel)` - Successfully created and initialized model
    /// - `Err(EimError)` - Failed to create model (invalid path, process spawn failure, etc.)
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use edge_impulse_runner::EimModel;
    ///
    /// let model = EimModel::new("path/to/model.eim").unwrap();
    /// ```
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, EimError> {
        Self::new_with_debug(path, false)
    }
248
    /// Creates a new EimModel instance with a specific Unix socket path.
    ///
    /// Similar to `new()`, but allows specifying the socket path for communication.
    /// This is useful when you need control over the socket location or when running
    /// multiple models simultaneously (each model needs its own socket path).
    ///
    /// # Arguments
    ///
    /// * `path` - Path to the .eim file
    /// * `socket_path` - Custom path where the Unix socket should be created
    pub fn new_with_socket<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        socket_path: S,
    ) -> Result<Self, EimError> {
        Self::new_with_socket_and_debug(path, socket_path, false)
    }
265
266    /// Create a new EimModel instance with debug output enabled
267    pub fn new_with_debug<P: AsRef<Path>>(path: P, debug: bool) -> Result<Self, EimError> {
268        let socket_path = std::env::temp_dir().join("eim_socket");
269        Self::new_with_socket_and_debug(path, &socket_path, debug)
270    }
271
272    /// Ensure the model file has execution permissions for the current user
273    fn ensure_executable<P: AsRef<Path>>(path: P) -> Result<(), EimError> {
274        use std::os::unix::fs::PermissionsExt;
275
276        let path = path.as_ref();
277        let metadata = std::fs::metadata(path)
278            .map_err(|e| EimError::ExecutionError(format!("Failed to get file metadata: {e}")))?;
279
280        let perms = metadata.permissions();
281        let current_mode = perms.mode();
282        if current_mode & 0o100 == 0 {
283            // File is not executable for user, try to make it executable
284            let mut new_perms = perms;
285            new_perms.set_mode(current_mode | 0o100); // Add executable bit for user only
286            std::fs::set_permissions(path, new_perms).map_err(|e| {
287                EimError::ExecutionError(format!("Failed to set executable permissions: {e}"))
288            })?;
289        }
290        Ok(())
291    }
292
293    /// Create a new EimModel instance with debug output enabled and a specific socket path
294    pub fn new_with_socket_and_debug<P: AsRef<Path>, S: AsRef<Path>>(
295        path: P,
296        socket_path: S,
297        debug: bool,
298    ) -> Result<Self, EimError> {
299        let path = path.as_ref();
300        let socket_path = socket_path.as_ref();
301
302        // Validate file extension
303        if path.extension().and_then(|s| s.to_str()) != Some("eim") {
304            return Err(EimError::InvalidPath);
305        }
306
307        // Convert relative path to absolute path
308        let absolute_path = if path.is_absolute() {
309            path.to_path_buf()
310        } else {
311            std::env::current_dir()
312                .map_err(|_e| EimError::InvalidPath)?
313                .join(path)
314        };
315
316        // Ensure the model file is executable
317        Self::ensure_executable(&absolute_path)?;
318
319        // Start the process
320        let process = std::process::Command::new(&absolute_path)
321            .arg(socket_path)
322            .spawn()
323            .map_err(|e| EimError::ExecutionError(e.to_string()))?;
324
325        let socket = Self::connect_with_retry(socket_path, Duration::from_secs(5))?;
326
327        let mut model = Self {
328            path: absolute_path, // Store the absolute path
329            socket_path: socket_path.to_path_buf(),
330            socket,
331            debug,
332            _process: process,
333            model_info: None,
334            message_id: AtomicU32::new(1),
335            child: None,
336            debug_callback: None,
337            continuous_state: None,
338            model_parameters: ModelParameters::default(),
339        };
340
341        // Initialize the model by sending hello message
342        model.send_hello()?;
343
344        Ok(model)
345    }
346
347    /// Attempts to connect to the Unix socket with a retry mechanism
348    ///
349    /// This function will repeatedly try to connect to the socket until either:
350    /// - A successful connection is established
351    /// - An unexpected error occurs
352    /// - The timeout duration is exceeded
353    ///
354    /// # Arguments
355    ///
356    /// * `socket_path` - Path to the Unix socket
357    /// * `timeout` - Maximum time to wait for connection
358    fn connect_with_retry(socket_path: &Path, timeout: Duration) -> Result<UnixStream, EimError> {
359        let start = Instant::now();
360        let retry_interval = Duration::from_millis(50);
361
362        while start.elapsed() < timeout {
363            match UnixStream::connect(socket_path) {
364                Ok(stream) => return Ok(stream),
365                Err(e) => {
366                    // NotFound and ConnectionRefused are expected errors while the socket
367                    // is being created, so we retry in these cases
368                    if e.kind() != std::io::ErrorKind::NotFound
369                        && e.kind() != std::io::ErrorKind::ConnectionRefused
370                    {
371                        return Err(EimError::SocketError(format!(
372                            "Failed to connect to socket: {e}"
373                        )));
374                    }
375                }
376            }
377            std::thread::sleep(retry_interval);
378        }
379
380        Err(EimError::SocketError(format!(
381            "Timeout waiting for socket {} to become available",
382            socket_path.display()
383        )))
384    }
385
    /// Get the next message ID.
    ///
    /// IDs come from an atomic counter (started at 1 in the constructor) so
    /// every request/response pair on the socket carries a unique id;
    /// `fetch_add` wraps on `u32` overflow.
    fn next_message_id(&self) -> u32 {
        self.message_id.fetch_add(1, Ordering::Relaxed)
    }
390
    /// Set a debug callback function to receive debug messages.
    ///
    /// When debug mode is enabled, this callback is invoked (in addition to
    /// stdout printing) with every message passed through `debug_message`.
    /// Replaces any previously registered callback.
    ///
    /// # Arguments
    ///
    /// * `callback` - Function that takes a string slice and handles the debug message
    pub fn set_debug_callback<F>(&mut self, callback: F)
    where
        F: Fn(&str) + Send + Sync + 'static,
    {
        self.debug_callback = Some(Box::new(callback));
    }
406
407    /// Send debug messages when debug mode is enabled
408    fn debug_message(&self, message: &str) {
409        if self.debug {
410            println!("{message}");
411            if let Some(callback) = &self.debug_callback {
412                callback(message);
413            }
414        }
415    }
416
    /// Perform the hello handshake: send a `HelloMessage` over the socket and
    /// cache the `ModelInfo` the model process replies with.
    ///
    /// Blocks until one line is read from the socket. On success
    /// `self.model_info` is populated; on failure an `EimError` describing
    /// the failing stage (send, flush, read, parse) is returned.
    fn send_hello(&mut self) -> Result<(), EimError> {
        let hello_msg = HelloMessage {
            hello: 1,
            id: self.next_message_id(),
        };

        // Protocol is newline-delimited JSON over the Unix socket.
        let msg = serde_json::to_string(&hello_msg)?;
        self.debug_message(&format!("Sending hello message: {msg}"));

        writeln!(self.socket, "{msg}").map_err(|e| {
            self.debug_message(&format!("Failed to send hello: {e}"));
            EimError::SocketError(format!("Failed to send hello message: {e}"))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush hello: {e}"));
            EimError::SocketError(format!("Failed to flush socket: {e}"))
        })?;

        self.debug_message("Waiting for hello response...");

        let mut reader = BufReader::new(&self.socket);
        let mut line = String::new();

        match reader.read_line(&mut line) {
            Ok(n) => {
                self.debug_message(&format!("Read {n} bytes: {line}"));

                // First try to parse the expected ModelInfo response.
                match serde_json::from_str::<ModelInfo>(&line) {
                    Ok(info) => {
                        self.debug_message("Successfully parsed model info");
                        if !info.success {
                            self.debug_message("Model initialization failed");
                            return Err(EimError::ExecutionError(
                                "Model initialization failed".to_string(),
                            ));
                        }
                        self.debug_message("Got model info response, storing it");
                        self.model_info = Some(info);
                        return Ok(());
                    }
                    Err(e) => {
                        self.debug_message(&format!("Failed to parse model info: {e}"));
                        // Fall back to parsing an explicit error response from
                        // the model process; otherwise drop through to the
                        // generic "no valid response" error below.
                        if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line) {
                            if !error.success {
                                self.debug_message(&format!("Got error response: {error:?}"));
                                return Err(EimError::ExecutionError(
                                    error.error.unwrap_or_else(|| "Unknown error".to_string()),
                                ));
                            }
                        }
                    }
                }
            }
            Err(e) => {
                self.debug_message(&format!("Failed to read hello response: {e}"));
                return Err(EimError::SocketError(format!(
                    "Failed to read response: {e}"
                )));
            }
        }

        self.debug_message("No valid hello response received");
        Err(EimError::SocketError(
            "No valid response received".to_string(),
        ))
    }
484
    /// Get the (absolute) path to the EIM file this model was loaded from.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Get the Unix socket path used for communication with the model process.
    pub fn socket_path(&self) -> &Path {
        &self.socket_path
    }
494
495    /// Get the sensor type for this model
496    pub fn sensor_type(&self) -> Result<SensorType, EimError> {
497        self.model_info
498            .as_ref()
499            .map(|info| SensorType::from(info.model_parameters.sensor))
500            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
501    }
502
503    /// Get the model parameters
504    pub fn parameters(&self) -> Result<&ModelParameters, EimError> {
505        self.model_info
506            .as_ref()
507            .map(|info| &info.model_parameters)
508            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
509    }
510
511    /// Run inference on the input features
512    ///
513    /// This method automatically handles both continuous and non-continuous modes:
514    ///
515    /// ## Non-Continuous Mode
516    /// - Each call is independent
517    /// - All features must be provided in a single call
518    /// - Results are returned immediately
519    ///
520    /// ## Continuous Mode (automatically enabled for supported models)
521    /// - Features are accumulated across calls
522    /// - Internal buffer maintains sliding window of features
523    /// - Moving average filter smooths results
524    /// - Initial calls may return empty results while buffer fills
525    ///
526    /// # Arguments
527    ///
528    /// * `features` - Vector of input features
529    /// * `debug` - Optional debug flag to enable detailed output for this inference
530    ///
531    /// # Returns
532    ///
533    /// Returns `Result<InferenceResponse, EimError>` containing inference results
534    pub fn infer(
535        &mut self,
536        features: Vec<f32>,
537        debug: Option<bool>,
538    ) -> Result<InferenceResponse, EimError> {
539        // Initialize model info if needed
540        if self.model_info.is_none() {
541            self.send_hello()?;
542        }
543
544        let uses_continuous_mode = self.requires_continuous_mode();
545
546        if uses_continuous_mode {
547            self.infer_continuous_internal(features, debug)
548        } else {
549            self.infer_single(features, debug)
550        }
551    }
552
553    fn infer_continuous_internal(
554        &mut self,
555        features: Vec<f32>,
556        debug: Option<bool>,
557    ) -> Result<InferenceResponse, EimError> {
558        // Initialize continuous state if needed
559        if self.continuous_state.is_none() {
560            let labels = self
561                .model_info
562                .as_ref()
563                .map(|info| info.model_parameters.labels.clone())
564                .unwrap_or_default();
565            let slice_size = self.input_size()?;
566
567            self.continuous_state = Some(ContinuousState::new(labels, slice_size));
568        }
569
570        // Take ownership of state temporarily to avoid multiple mutable borrows
571        let mut state = self.continuous_state.take().unwrap();
572        state.update_features(&features);
573
574        let response = if !state.feature_buffer_full {
575            // Return empty response while building up the buffer
576            Ok(InferenceResponse {
577                success: true,
578                id: self.next_message_id(),
579                result: InferenceResult::Classification {
580                    classification: HashMap::new(),
581                },
582            })
583        } else {
584            // Run inference on the full buffer
585            let mut response = self.infer_single(state.feature_matrix.clone(), debug)?;
586
587            // Apply moving average filter to the results
588            if let InferenceResult::Classification {
589                ref mut classification,
590            } = response.result
591            {
592                state.apply_maf(classification);
593            }
594
595            Ok(response)
596        };
597
598        // Restore the state
599        self.continuous_state = Some(state);
600
601        response
602    }
603
    /// One-shot inference: send a `ClassifyMessage` and poll the socket (up
    /// to 5 seconds) for a successful `InferenceResponse`.
    ///
    /// The socket is switched to non-blocking mode for the read loop and
    /// restored to blocking mode on every exit path (success, error, timeout).
    fn infer_single(
        &mut self,
        features: Vec<f32>,
        debug: Option<bool>,
    ) -> Result<InferenceResponse, EimError> {
        // First ensure we've sent the hello message and received model info
        if self.model_info.is_none() {
            self.debug_message("No model info, sending hello message...");
            self.send_hello()?;
            self.debug_message("Hello handshake completed");
        }

        let msg = ClassifyMessage {
            classify: features.clone(),
            id: self.next_message_id(),
            debug,
        };

        let msg_str = serde_json::to_string(&msg)?;
        self.debug_message(&format!(
            "Sending inference message with {} features",
            features.len()
        ));

        writeln!(self.socket, "{msg_str}").map_err(|e| {
            self.debug_message(&format!("Failed to send inference message: {e}"));
            EimError::SocketError(format!("Failed to send inference message: {e}"))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush inference message: {e}"));
            EimError::SocketError(format!("Failed to flush socket: {e}"))
        })?;

        self.debug_message("Inference message sent, waiting for response...");

        // Set socket to non-blocking mode so the loop below can enforce its
        // own timeout instead of blocking indefinitely on a read.
        self.socket.set_nonblocking(true).map_err(|e| {
            self.debug_message(&format!("Failed to set non-blocking mode: {e}"));
            EimError::SocketError(format!("Failed to set non-blocking mode: {e}"))
        })?;

        let mut reader = BufReader::new(&self.socket);
        let mut buffer = String::new();
        let start = Instant::now();
        let timeout = Duration::from_secs(5);

        while start.elapsed() < timeout {
            match reader.read_line(&mut buffer) {
                Ok(0) => {
                    self.debug_message("EOF reached");
                    break;
                }
                Ok(n) => {
                    // Skip printing feature values in the response
                    if !buffer.contains("features:") && !buffer.contains("Features (") {
                        self.debug_message(&format!("Read {n} bytes: {buffer}"));
                    }

                    // Lines that don't parse as a successful InferenceResponse
                    // (debug output, unrelated messages) are discarded and the
                    // loop keeps reading until the timeout.
                    if let Ok(response) = serde_json::from_str::<InferenceResponse>(&buffer) {
                        if response.success {
                            self.debug_message("Got successful inference response");
                            // Reset to blocking mode before returning
                            let _ = self.socket.set_nonblocking(false);
                            return Ok(response);
                        }
                    }
                    buffer.clear();
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    // No data available yet, sleep briefly and retry
                    std::thread::sleep(Duration::from_millis(10));
                    continue;
                }
                Err(e) => {
                    self.debug_message(&format!("Read error: {e}"));
                    // Always try to reset blocking mode, even on error
                    let _ = self.socket.set_nonblocking(false);
                    return Err(EimError::SocketError(format!("Read error: {e}")));
                }
            }
        }

        // Reset to blocking mode before returning
        let _ = self.socket.set_nonblocking(false);
        self.debug_message("Timeout reached");

        Err(EimError::ExecutionError(format!(
            "No valid response received within {} seconds",
            timeout.as_secs()
        )))
    }
696
697    /// Check if model requires continuous mode
698    fn requires_continuous_mode(&self) -> bool {
699        self.model_info
700            .as_ref()
701            .map(|info| info.model_parameters.use_continuous_mode)
702            .unwrap_or(false)
703    }
704
705    /// Get the required number of input features for this model
706    ///
707    /// Returns the number of features expected by the model for each classification.
708    /// This is useful for:
709    /// - Validating input size before classification
710    /// - Preparing the correct amount of data
711    /// - Padding or truncating inputs to match model requirements
712    ///
713    /// # Returns
714    ///
715    /// The number of input features required by the model
716    pub fn input_size(&self) -> Result<usize, EimError> {
717        self.model_info
718            .as_ref()
719            .map(|info| info.model_parameters.input_features_count as usize)
720            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
721    }
722
    /// Set a threshold for a specific learning block.
    ///
    /// This method allows updating thresholds for different types of blocks:
    /// - Anomaly detection (GMM)
    /// - Object detection
    /// - Object tracking
    ///
    /// NOTE(review): this fn is declared `async` but performs only blocking
    /// socket I/O and never awaits — calling it from an async runtime will
    /// block the executor thread; confirm whether `async` is intentional.
    ///
    /// # Arguments
    ///
    /// * `threshold` - The threshold configuration to set
    ///
    /// # Returns
    ///
    /// Returns `Result<(), EimError>` indicating success or failure
    pub async fn set_learn_block_threshold(
        &mut self,
        threshold: ThresholdConfig,
    ) -> Result<(), EimError> {
        // First check if model info is available and supports thresholds
        if self.model_info.is_none() {
            self.debug_message("No model info available, sending hello message...");
            self.send_hello()?;
        }

        // Log the current model state
        if let Some(info) = &self.model_info {
            self.debug_message(&format!(
                "Current model type: {}",
                info.model_parameters.model_type
            ));
            self.debug_message(&format!(
                "Current model parameters: {:?}",
                info.model_parameters
            ));
        }

        let msg = SetThresholdMessage {
            set_threshold: threshold,
            id: self.next_message_id(),
        };

        let msg_str = serde_json::to_string(&msg)?;
        self.debug_message(&format!("Sending threshold message: {msg_str}"));

        writeln!(self.socket, "{msg_str}").map_err(|e| {
            self.debug_message(&format!("Failed to send threshold message: {e}"));
            EimError::SocketError(format!("Failed to send threshold message: {e}"))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush threshold message: {e}"));
            EimError::SocketError(format!("Failed to flush socket: {e}"))
        })?;

        // Blocking read of the single-line JSON response.
        let mut reader = BufReader::new(&self.socket);
        let mut line = String::new();

        match reader.read_line(&mut line) {
            Ok(_) => {
                self.debug_message(&format!("Received response: {line}"));
                match serde_json::from_str::<SetThresholdResponse>(&line) {
                    Ok(response) => {
                        if response.success {
                            self.debug_message("Successfully set threshold");
                            Ok(())
                        } else {
                            self.debug_message("Server reported failure setting threshold");
                            Err(EimError::ExecutionError(
                                "Server reported failure setting threshold".to_string(),
                            ))
                        }
                    }
                    Err(e) => {
                        self.debug_message(&format!("Failed to parse threshold response: {e}"));
                        // Try to parse as error response
                        if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line) {
                            Err(EimError::ExecutionError(
                                error.error.unwrap_or_else(|| "Unknown error".to_string()),
                            ))
                        } else {
                            Err(EimError::ExecutionError(format!(
                                "Invalid threshold response format: {e}"
                            )))
                        }
                    }
                }
            }
            Err(e) => {
                self.debug_message(&format!("Failed to read threshold response: {e}"));
                Err(EimError::SocketError(format!(
                    "Failed to read response: {e}"
                )))
            }
        }
    }
818
819    /// Get the minimum anomaly score threshold from model parameters
820    fn get_min_anomaly_score(&self) -> f32 {
821        self.model_info
822            .as_ref()
823            .and_then(|info| {
824                info.model_parameters
825                    .thresholds
826                    .iter()
827                    .find_map(|t| match t {
828                        ModelThreshold::AnomalyGMM {
829                            min_anomaly_score, ..
830                        } => Some(*min_anomaly_score),
831                        _ => None,
832                    })
833            })
834            .unwrap_or(6.0)
835    }
836
837    /// Normalize an anomaly score relative to the model's minimum threshold
838    fn normalize_anomaly_score(&self, score: f32) -> f32 {
839        (score / self.get_min_anomaly_score()).min(1.0)
840    }
841
842    /// Normalize a visual anomaly result
843    pub fn normalize_visual_anomaly(
844        &self,
845        anomaly: f32,
846        max: f32,
847        mean: f32,
848        regions: &[(f32, u32, u32, u32, u32)],
849    ) -> VisualAnomalyResult {
850        let normalized_anomaly = self.normalize_anomaly_score(anomaly);
851        let normalized_max = self.normalize_anomaly_score(max);
852        let normalized_mean = self.normalize_anomaly_score(mean);
853        let normalized_regions: Vec<_> = regions
854            .iter()
855            .map(|(value, x, y, w, h)| (self.normalize_anomaly_score(*value), *x, *y, *w, *h))
856            .collect();
857
858        (
859            normalized_anomaly,
860            normalized_max,
861            normalized_mean,
862            normalized_regions,
863        )
864    }
865}