edge_impulse_runner/backends/
eim.rs1use super::{BackendConfig, InferenceBackend};
7use crate::error::EdgeImpulseError;
8use crate::inference::messages::InferenceResponse;
9use crate::types::{ModelParameters, SensorType, VisualAnomalyResult};
10use rand::{Rng, thread_rng};
11use std::io::{BufRead, BufReader, Write};
13use std::os::unix::net::UnixStream;
14use std::path::Path;
15use std::process::Child;
16use std::sync::atomic::{AtomicU32, Ordering};
17use std::time::{Duration, Instant};
18use tempfile::{TempDir, tempdir};
19
20use crate::inference::messages::{ClassifyMessage, HelloMessage, InferenceResult, ModelInfo};
21
/// Callback type for receiving debug/trace messages emitted by the backend.
pub type DebugCallback = Box<dyn Fn(&str) + Send + Sync>;
24
/// Inference backend that drives an Edge Impulse `.eim` model binary over a
/// Unix domain socket, speaking a newline-delimited JSON protocol.
pub struct EimBackend {
    /// Path to the `.eim` executable that was launched.
    #[allow(dead_code)]
    path: std::path::PathBuf,
    /// Socket path handed to the model process (lives inside `tempdir`).
    #[allow(dead_code)]
    socket_path: std::path::PathBuf,
    /// Keeps the temporary directory alive so the socket path is not removed
    /// while the child process is still using it.
    #[allow(dead_code)]
    tempdir: Option<TempDir>,
    /// Connected stream used for all request/response messaging.
    socket: UnixStream,
    /// Debug flag; currently always `false` (see `EimBackend::new`).
    #[allow(dead_code)]
    debug: bool,
    /// Optional sink invoked by `debug_message` with protocol traces.
    debug_callback: Option<DebugCallback>,
    /// Handle to the spawned model process, held for the backend's lifetime.
    _process: Child,
    /// Model info captured from the `hello` handshake response.
    model_info: Option<ModelInfo>,
    /// Monotonically increasing id attached to each outgoing message.
    message_id: AtomicU32,
    /// Parameters reported by the model during the handshake.
    model_parameters: ModelParameters,
}
52
impl EimBackend {
    /// Spawns the EIM model process and establishes the socket control channel.
    ///
    /// Steps: create a tempdir holding a randomly named socket path, make the
    /// model binary executable if needed, launch it with the socket path as its
    /// only argument, connect with retries (up to 10 s), then perform the
    /// `hello` handshake to fetch the model's parameters.
    ///
    /// # Errors
    /// Returns an `EdgeImpulseError` if the config is not `BackendConfig::Eim`,
    /// the tempdir or child process cannot be created, the socket never becomes
    /// connectable, or the handshake fails.
    pub fn new(config: BackendConfig) -> Result<Self, EdgeImpulseError> {
        // Only the EIM variant of the config is accepted.
        // NOTE(review): any other fields carried by `BackendConfig::Eim` (e.g. a
        // debug flag) are discarded by the `..` pattern — confirm that is intended,
        // since `debug` is hardcoded to `false` below.
        let BackendConfig::Eim { path, .. } = config else {
            return Err(EdgeImpulseError::InvalidOperation(
                "Invalid config type for EIM backend".to_string(),
            ));
        };

        let tempdir = tempdir().map_err(|e| {
            EdgeImpulseError::ExecutionError(format!("Failed to create tempdir: {e}"))
        })?;
        // Random suffix avoids collisions when several backends start concurrently.
        // (`r#gen` raw identifier because `gen` is reserved in the 2024 edition.)
        let mut rng = thread_rng();
        let socket_name = format!("eim_socket_{}", rng.r#gen::<u64>());
        let socket_path = tempdir.path().join(socket_name);

        Self::ensure_executable(&path)?;

        // NOTE(review): unconditional stdout output from library code; consider
        // routing through the debug callback once one can be supplied this early.
        println!(
            "Starting EIM process: {} {}",
            path.display(),
            socket_path.display()
        );
        // The model binary creates and listens on the socket path given as argv[1].
        let process = std::process::Command::new(&path)
            .arg(&socket_path)
            .spawn()
            .map_err(|e| {
                EdgeImpulseError::ExecutionError(format!("Failed to start model process: {e}"))
            })?;

        // Give the child up to 10 seconds to create the socket.
        let socket = Self::connect_with_retry(&socket_path, Duration::from_secs(10))?;

        let mut backend = Self {
            path,
            socket_path,
            tempdir: Some(tempdir),
            socket,
            debug: false,
            debug_callback: None,
            _process: process,
            model_info: None,
            message_id: AtomicU32::new(1),
            model_parameters: ModelParameters::default(),
        };

        // Handshake: populates `model_info` / `model_parameters` and validates
        // that the channel actually works.
        backend.send_hello()?;

        Ok(backend)
    }

    /// Ensures the model binary has the owner-execute bit set, chmod-ing it on
    /// when missing (Unix-only; uses `PermissionsExt`).
    ///
    /// # Errors
    /// Returns `ExecutionError` if metadata cannot be read or permissions
    /// cannot be updated.
    fn ensure_executable<P: AsRef<Path>>(path: P) -> Result<(), EdgeImpulseError> {
        use std::os::unix::fs::PermissionsExt;

        let path = path.as_ref();
        let metadata = std::fs::metadata(path).map_err(|e| {
            EdgeImpulseError::ExecutionError(format!("Failed to get file metadata: {e}"))
        })?;

        let perms = metadata.permissions();
        let current_mode = perms.mode();
        // 0o100 is the owner-execute bit; only chmod when it is absent.
        if current_mode & 0o100 == 0 {
            let mut new_perms = perms;
            new_perms.set_mode(current_mode | 0o100);
            std::fs::set_permissions(path, new_perms).map_err(|e| {
                EdgeImpulseError::ExecutionError(format!(
                    "Failed to set executable permissions: {e}"
                ))
            })?;
        }
        Ok(())
    }

    /// Polls the Unix socket until it accepts a connection or `timeout` elapses,
    /// sleeping 100 ms between attempts (the child needs time to create it).
    ///
    /// NOTE(review): a child that exits without ever creating the socket is not
    /// detected, so a crashed process still costs the full timeout.
    fn connect_with_retry(
        socket_path: &Path,
        timeout: Duration,
    ) -> Result<UnixStream, EdgeImpulseError> {
        println!("Attempting to connect to socket: {}", socket_path.display());
        let start = Instant::now();
        while start.elapsed() < timeout {
            match UnixStream::connect(socket_path) {
                Ok(socket) => {
                    println!("Successfully connected to socket");
                    return Ok(socket);
                }
                Err(_e) => {
                    std::thread::sleep(Duration::from_millis(100));
                }
            }
        }
        Err(EdgeImpulseError::SocketError(format!(
            "Timeout waiting for socket {} to become available",
            socket_path.display()
        )))
    }

    /// Performs the initial `hello` handshake: writes one newline-terminated
    /// JSON message, reads a single response line, parses it as `ModelInfo`,
    /// and caches the reported model parameters.
    fn send_hello(&mut self) -> Result<(), EdgeImpulseError> {
        let hello = HelloMessage {
            id: self.next_message_id(),
            hello: 1,
        };

        let hello_json = serde_json::to_string(&hello).map_err(|e| {
            EdgeImpulseError::InvalidOperation(format!("Failed to serialize hello: {e}"))
        })?;

        self.debug_message(&format!("Sending hello: {hello_json}"));

        // Protocol framing: JSON payload followed by a bare '\n'.
        self.socket
            .write_all(hello_json.as_bytes())
            .map_err(|e| EdgeImpulseError::ExecutionError(format!("Failed to send hello: {e}")))?;
        self.socket.write_all(b"\n").map_err(|e| {
            EdgeImpulseError::ExecutionError(format!("Failed to send newline: {e}"))
        })?;

        // NOTE(review): a fresh BufReader is built per exchange; if the process
        // ever sends more than one line at a time, bytes buffered past the first
        // line are lost when the reader is dropped. Confirm the protocol is
        // strictly one response line per request.
        let mut reader = BufReader::new(&self.socket);
        let mut response = String::new();
        reader.read_line(&mut response).map_err(|e| {
            EdgeImpulseError::ExecutionError(format!("Failed to read hello response: {e}"))
        })?;

        self.debug_message(&format!("Received hello response: {}", response.trim()));

        let model_info: ModelInfo = serde_json::from_str(&response).map_err(|e| {
            EdgeImpulseError::InvalidOperation(format!("Failed to parse hello response: {e}"))
        })?;

        self.model_info = Some(model_info.clone());

        // Cache the parameters so accessors don't need to re-parse `model_info`.
        self.model_parameters = model_info.model_parameters;

        Ok(())
    }

    /// Returns the next message id (post-increment on the shared atomic counter).
    fn next_message_id(&self) -> u32 {
        self.message_id.fetch_add(1, Ordering::SeqCst)
    }

    /// Installs the sink that receives protocol debug traces.
    pub fn set_debug_callback(&mut self, callback: DebugCallback) {
        self.debug_callback = Some(callback);
    }

    /// Forwards `msg` to the debug callback when one is installed; no-op otherwise.
    fn debug_message(&self, msg: &str) {
        if let Some(callback) = &self.debug_callback {
            callback(msg);
        }
    }

    /// Sends one `classify` request with the given feature vector and returns
    /// the parsed inference result.
    ///
    /// On a malformed response the raw line is logged to stderr before the
    /// parse error is returned, to aid protocol debugging.
    fn classify(&mut self, input: &[f32]) -> Result<InferenceResult, EdgeImpulseError> {
        let classify = ClassifyMessage {
            id: self.next_message_id(),
            classify: input.to_vec(),
            debug: None,
        };

        let classify_json = serde_json::to_string(&classify).map_err(|e| {
            EdgeImpulseError::InvalidOperation(format!("Failed to serialize classify: {e}"))
        })?;

        // Same framing as the hello exchange: JSON line + '\n'.
        self.socket
            .write_all(classify_json.as_bytes())
            .map_err(|e| {
                EdgeImpulseError::ExecutionError(format!("Failed to send classify: {e}"))
            })?;
        self.socket.write_all(b"\n").map_err(|e| {
            EdgeImpulseError::ExecutionError(format!("Failed to send newline: {e}"))
        })?;

        let mut reader = BufReader::new(&self.socket);
        let mut response_json = String::new();
        reader.read_line(&mut response_json).map_err(|e| {
            EdgeImpulseError::ExecutionError(format!("Failed to read classify response: {e}"))
        })?;

        let response: InferenceResponse = match serde_json::from_str(&response_json) {
            Ok(r) => r,
            Err(e) => {
                eprintln!(
                    "[EIM backend] Failed to parse classify response: {}\nRaw response: {}",
                    e,
                    response_json.trim()
                );
                return Err(EdgeImpulseError::InvalidOperation(format!(
                    "Failed to parse classify response: {e}"
                )));
            }
        };

        Ok(response.result)
    }
}
260
261impl InferenceBackend for EimBackend {
262 fn new(config: BackendConfig) -> Result<Self, EdgeImpulseError> {
263 EimBackend::new(config)
264 }
265
266 fn infer(
267 &mut self,
268 features: Vec<f32>,
269 _debug: Option<bool>,
270 ) -> Result<InferenceResponse, EdgeImpulseError> {
271 let result = self.classify(&features)?;
273 Ok(InferenceResponse {
274 success: true,
275 id: self.next_message_id(),
276 result,
277 })
278 }
279
280 fn parameters(&self) -> Result<&ModelParameters, EdgeImpulseError> {
281 Ok(&self.model_parameters)
282 }
283
284 fn sensor_type(&self) -> Result<SensorType, EdgeImpulseError> {
285 Ok(SensorType::from(self.model_parameters.sensor))
287 }
288
289 fn input_size(&self) -> Result<usize, EdgeImpulseError> {
290 Ok(self.model_parameters.input_features_count as usize)
291 }
292
293 fn set_debug_callback(&mut self, callback: Box<dyn Fn(&str) + Send + Sync>) {
294 self.set_debug_callback(callback);
295 }
296
297 fn normalize_visual_anomaly(
298 &self,
299 anomaly: f32,
300 max: f32,
301 mean: f32,
302 regions: &[(f32, u32, u32, u32, u32)],
303 ) -> VisualAnomalyResult {
304 (anomaly, max, mean, regions.to_vec())
305 }
306}