// edge_impulse_runner/backends/eim.rs

//! EIM backend implementation
//!
//! This module provides the EIM backend that communicates with Edge Impulse
//! binary files over Unix sockets.

use super::{BackendConfig, InferenceBackend};
use crate::error::EdgeImpulseError;
use crate::inference::messages::{
    ClassifyMessage, HelloMessage, InferenceResponse, InferenceResult, ModelInfo,
};
use crate::types::{ModelParameters, SensorType, VisualAnomalyResult};

use rand::{Rng, thread_rng};
use tempfile::{TempDir, tempdir};

use std::io::{BufRead, BufReader, Write};
use std::os::unix::net::UnixStream;
use std::path::Path;
use std::process::Child;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};

/// Debug callback type for receiving debug messages
pub type DebugCallback = Box<dyn Fn(&str) + Send + Sync>;

/// EIM backend implementation for socket-based communication
///
/// Spawns the `.eim` model binary as a child process and talks to it over a
/// newline-delimited JSON protocol on a Unix domain socket.
pub struct EimBackend {
    /// Path to the Edge Impulse model file (.eim)
    #[allow(dead_code)]
    path: std::path::PathBuf,
    /// Path to the Unix socket used for IPC
    #[allow(dead_code)]
    socket_path: std::path::PathBuf,
    /// Handle to the temporary directory holding the socket; keeping it in
    /// the struct ensures the directory is cleaned up when the backend drops
    #[allow(dead_code)]
    tempdir: Option<TempDir>,
    /// Active Unix socket connection to the model process
    socket: UnixStream,
    /// Enable debug logging of socket communications
    /// (currently always `false` — `new` never sets it; see `EimBackend::new`)
    #[allow(dead_code)]
    debug: bool,
    /// Optional debug callback for receiving debug messages
    debug_callback: Option<DebugCallback>,
    /// Handle to the model process (kept alive while the backend exists;
    /// note that dropping a `Child` does not kill the process)
    _process: Child,
    /// Cached model information received during the hello handshake
    model_info: Option<ModelInfo>,
    /// Atomic counter for generating unique message IDs (starts at 1)
    message_id: AtomicU32,
    /// Model parameters extracted from the hello response
    model_parameters: ModelParameters,
}

53impl EimBackend {
54    /// Create a new EIM backend
55    pub fn new(config: BackendConfig) -> Result<Self, EdgeImpulseError> {
56        let BackendConfig::Eim { path, .. } = config else {
57            return Err(EdgeImpulseError::InvalidOperation(
58                "Invalid config type for EIM backend".to_string(),
59            ));
60        };
61
62        // Always generate a temp socket path
63        let tempdir = tempdir().map_err(|e| {
64            EdgeImpulseError::ExecutionError(format!("Failed to create tempdir: {e}"))
65        })?;
66        let mut rng = thread_rng();
67        let socket_name = format!("eim_socket_{}", rng.r#gen::<u64>());
68        let socket_path = tempdir.path().join(socket_name);
69
70        // Ensure the model file has execution permissions
71        Self::ensure_executable(&path)?;
72
73        // Start the model process with the socket path as the first positional argument
74        println!(
75            "Starting EIM process: {} {}",
76            path.display(),
77            socket_path.display()
78        );
79        let process = std::process::Command::new(&path)
80            .arg(&socket_path)
81            .spawn()
82            .map_err(|e| {
83                EdgeImpulseError::ExecutionError(format!("Failed to start model process: {e}"))
84            })?;
85
86        // Wait for the socket to be created and connect
87        let socket = Self::connect_with_retry(&socket_path, Duration::from_secs(10))?;
88
89        let mut backend = Self {
90            path,
91            socket_path,
92            tempdir: Some(tempdir),
93            socket,
94            debug: false,
95            debug_callback: None,
96            _process: process,
97            model_info: None,
98            message_id: AtomicU32::new(1),
99            model_parameters: ModelParameters::default(),
100        };
101
102        // Send hello message to get model info
103        backend.send_hello()?;
104
105        Ok(backend)
106    }
107
108    /// Ensure the model file has execution permissions for the current user
109    fn ensure_executable<P: AsRef<Path>>(path: P) -> Result<(), EdgeImpulseError> {
110        use std::os::unix::fs::PermissionsExt;
111
112        let path = path.as_ref();
113        let metadata = std::fs::metadata(path).map_err(|e| {
114            EdgeImpulseError::ExecutionError(format!("Failed to get file metadata: {e}"))
115        })?;
116
117        let perms = metadata.permissions();
118        let current_mode = perms.mode();
119        if current_mode & 0o100 == 0 {
120            // File is not executable for user, try to make it executable
121            let mut new_perms = perms;
122            new_perms.set_mode(current_mode | 0o100); // Add executable bit for user only
123            std::fs::set_permissions(path, new_perms).map_err(|e| {
124                EdgeImpulseError::ExecutionError(format!(
125                    "Failed to set executable permissions: {e}"
126                ))
127            })?;
128        }
129        Ok(())
130    }
131
132    /// Connect to the socket with retry logic
133    fn connect_with_retry(
134        socket_path: &Path,
135        timeout: Duration,
136    ) -> Result<UnixStream, EdgeImpulseError> {
137        println!("Attempting to connect to socket: {}", socket_path.display());
138        let start = Instant::now();
139        while start.elapsed() < timeout {
140            match UnixStream::connect(socket_path) {
141                Ok(socket) => {
142                    println!("Successfully connected to socket");
143                    return Ok(socket);
144                }
145                Err(_e) => {
146                    std::thread::sleep(Duration::from_millis(100));
147                }
148            }
149        }
150        Err(EdgeImpulseError::SocketError(format!(
151            "Timeout waiting for socket {} to become available",
152            socket_path.display()
153        )))
154    }
155
156    /// Send hello message to get model information
157    fn send_hello(&mut self) -> Result<(), EdgeImpulseError> {
158        let hello = HelloMessage {
159            id: self.next_message_id(),
160            hello: 1,
161        };
162
163        let hello_json = serde_json::to_string(&hello).map_err(|e| {
164            EdgeImpulseError::InvalidOperation(format!("Failed to serialize hello: {e}"))
165        })?;
166
167        self.debug_message(&format!("Sending hello: {hello_json}"));
168
169        // Send the message
170        self.socket
171            .write_all(hello_json.as_bytes())
172            .map_err(|e| EdgeImpulseError::ExecutionError(format!("Failed to send hello: {e}")))?;
173        self.socket.write_all(b"\n").map_err(|e| {
174            EdgeImpulseError::ExecutionError(format!("Failed to send newline: {e}"))
175        })?;
176
177        // Read the response
178        let mut reader = BufReader::new(&self.socket);
179        let mut response = String::new();
180        reader.read_line(&mut response).map_err(|e| {
181            EdgeImpulseError::ExecutionError(format!("Failed to read hello response: {e}"))
182        })?;
183
184        self.debug_message(&format!("Received hello response: {}", response.trim()));
185
186        // Parse the response
187        let model_info: ModelInfo = serde_json::from_str(&response).map_err(|e| {
188            EdgeImpulseError::InvalidOperation(format!("Failed to parse hello response: {e}"))
189        })?;
190
191        self.model_info = Some(model_info.clone());
192
193        // Extract model parameters from the model info
194        self.model_parameters = model_info.model_parameters;
195
196        Ok(())
197    }
198
199    /// Generate the next unique message ID
200    fn next_message_id(&self) -> u32 {
201        self.message_id.fetch_add(1, Ordering::SeqCst)
202    }
203
204    /// Set the debug callback
205    pub fn set_debug_callback(&mut self, callback: DebugCallback) {
206        self.debug_callback = Some(callback);
207    }
208
209    /// Send a debug message if a callback is set
210    fn debug_message(&self, msg: &str) {
211        if let Some(callback) = &self.debug_callback {
212            callback(msg);
213        }
214    }
215
216    /// Classify a single input
217    fn classify(&mut self, input: &[f32]) -> Result<InferenceResult, EdgeImpulseError> {
218        let classify = ClassifyMessage {
219            id: self.next_message_id(),
220            classify: input.to_vec(),
221            debug: None,
222        };
223
224        let classify_json = serde_json::to_string(&classify).map_err(|e| {
225            EdgeImpulseError::InvalidOperation(format!("Failed to serialize classify: {e}"))
226        })?;
227
228        self.socket
229            .write_all(classify_json.as_bytes())
230            .map_err(|e| {
231                EdgeImpulseError::ExecutionError(format!("Failed to send classify: {e}"))
232            })?;
233        self.socket.write_all(b"\n").map_err(|e| {
234            EdgeImpulseError::ExecutionError(format!("Failed to send newline: {e}"))
235        })?;
236
237        let mut reader = BufReader::new(&self.socket);
238        let mut response_json = String::new();
239        reader.read_line(&mut response_json).map_err(|e| {
240            EdgeImpulseError::ExecutionError(format!("Failed to read classify response: {e}"))
241        })?;
242
243        let response: InferenceResponse = match serde_json::from_str(&response_json) {
244            Ok(r) => r,
245            Err(e) => {
246                eprintln!(
247                    "[EIM backend] Failed to parse classify response: {}\nRaw response: {}",
248                    e,
249                    response_json.trim()
250                );
251                return Err(EdgeImpulseError::InvalidOperation(format!(
252                    "Failed to parse classify response: {e}"
253                )));
254            }
255        };
256
257        Ok(response.result)
258    }
259}
260
261impl InferenceBackend for EimBackend {
262    fn new(config: BackendConfig) -> Result<Self, EdgeImpulseError> {
263        EimBackend::new(config)
264    }
265
266    fn infer(
267        &mut self,
268        features: Vec<f32>,
269        _debug: Option<bool>,
270    ) -> Result<InferenceResponse, EdgeImpulseError> {
271        // Use classify and wrap in InferenceResponse
272        let result = self.classify(&features)?;
273        Ok(InferenceResponse {
274            success: true,
275            id: self.next_message_id(),
276            result,
277        })
278    }
279
280    fn parameters(&self) -> Result<&ModelParameters, EdgeImpulseError> {
281        Ok(&self.model_parameters)
282    }
283
284    fn sensor_type(&self) -> Result<SensorType, EdgeImpulseError> {
285        // Convert from i32 to SensorType
286        Ok(SensorType::from(self.model_parameters.sensor))
287    }
288
289    fn input_size(&self) -> Result<usize, EdgeImpulseError> {
290        Ok(self.model_parameters.input_features_count as usize)
291    }
292
293    fn set_debug_callback(&mut self, callback: Box<dyn Fn(&str) + Send + Sync>) {
294        self.set_debug_callback(callback);
295    }
296
297    fn normalize_visual_anomaly(
298        &self,
299        anomaly: f32,
300        max: f32,
301        mean: f32,
302        regions: &[(f32, u32, u32, u32, u32)],
303    ) -> VisualAnomalyResult {
304        (anomaly, max, mean, regions.to_vec())
305    }
306}