1use std::collections::{HashMap, VecDeque};
2use std::fmt;
3use std::io::{BufRead, BufReader, Write};
4use std::os::unix::net::UnixStream;
5use std::path::Path;
6use std::process::Child;
7use std::sync::atomic::{AtomicU32, Ordering};
8use std::time::{Duration, Instant};
9
10use crate::error::EimError;
11use crate::inference::messages::{
12 ClassifyMessage, ErrorResponse, HelloMessage, InferenceResponse, InferenceResult, ModelInfo,
13 SetThresholdMessage, SetThresholdResponse, ThresholdConfig,
14};
15use crate::types::{ModelParameters, ModelThreshold, SensorType, VisualAnomalyResult};
16
/// Sink for debug log lines; boxed so callers may pass any closure.
/// Must be `Send + Sync` because the model handle may move across threads.
pub type DebugCallback = Box<dyn Fn(&str) + Send + Sync>;
19
/// Handle to a running Edge Impulse `.eim` model process, communicating
/// over a Unix domain socket using newline-delimited JSON messages.
pub struct EimModel {
    // Absolute path to the `.eim` executable this instance spawned.
    path: std::path::PathBuf,
    // Path of the Unix socket the runner process was told to bind.
    socket_path: std::path::PathBuf,
    // Connected stream to the runner process.
    socket: UnixStream,
    // When true, debug_message() prints to stdout and forwards to the callback.
    debug: bool,
    // Optional user-supplied sink for debug lines; omitted from Debug output
    // because closures have no Debug representation.
    debug_callback: Option<DebugCallback>,
    // Owned runner process; kept alive for the lifetime of this struct.
    _process: Child,
    // Metadata from the hello handshake; None until send_hello() succeeds.
    model_info: Option<ModelInfo>,
    // Monotonically increasing id attached to outgoing messages.
    message_id: AtomicU32,
    // Unused secondary child handle; always None in this file.
    #[allow(dead_code)]
    child: Option<Child>,
    // Rolling feature-window state for continuous-mode inference.
    continuous_state: Option<ContinuousState>,
    // Local parameter cache; initialized to defaults at construction.
    model_parameters: ModelParameters,
}
129
130#[derive(Debug)]
131struct ContinuousState {
132 feature_matrix: Vec<f32>,
133 feature_buffer_full: bool,
134 maf_buffers: HashMap<String, MovingAverageFilter>,
135 slice_size: usize,
136}
137
138impl ContinuousState {
139 fn new(labels: Vec<String>, slice_size: usize) -> Self {
140 Self {
141 feature_matrix: Vec::new(),
142 feature_buffer_full: false,
143 maf_buffers: labels
144 .into_iter()
145 .map(|label| (label, MovingAverageFilter::new(4)))
146 .collect(),
147 slice_size,
148 }
149 }
150
151 fn update_features(&mut self, features: &[f32]) {
152 self.feature_matrix.extend_from_slice(features);
154
155 if self.feature_matrix.len() >= self.slice_size {
157 self.feature_buffer_full = true;
158 if self.feature_matrix.len() > self.slice_size {
160 self.feature_matrix
161 .drain(0..self.feature_matrix.len() - self.slice_size);
162 }
163 }
164 }
165
166 fn apply_maf(&mut self, classification: &mut HashMap<String, f32>) {
167 for (label, value) in classification.iter_mut() {
168 if let Some(maf) = self.maf_buffers.get_mut(label) {
169 *value = maf.update(*value);
170 }
171 }
172 }
173}
174
/// Fixed-width moving average over the last `window_size` samples,
/// maintained incrementally with a running sum.
#[derive(Debug)]
struct MovingAverageFilter {
    buffer: VecDeque<f32>,
    window_size: usize,
    sum: f32,
}

impl MovingAverageFilter {
    /// Creates a filter that averages over `window_size` samples.
    fn new(window_size: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(window_size),
            window_size,
            sum: 0.0,
        }
    }

    /// Pushes `sample` and returns the mean of the samples currently held.
    /// Before the window fills, the mean covers however many samples exist.
    fn update(&mut self, sample: f32) -> f32 {
        // Evict the oldest sample once the window is at capacity so the
        // buffer never holds more than `window_size` entries.
        while self.buffer.len() >= self.window_size {
            let evicted = self
                .buffer
                .pop_front()
                .expect("buffer is non-empty when at capacity");
            self.sum -= evicted;
        }
        self.buffer.push_back(sample);
        self.sum += sample;
        self.sum / self.buffer.len() as f32
    }
}
200
impl fmt::Debug for EimModel {
    // Manual impl rather than #[derive(Debug)]: `debug_callback` is a boxed
    // closure with no Debug representation, so it is deliberately omitted
    // from the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("EimModel")
            .field("path", &self.path)
            .field("socket_path", &self.socket_path)
            .field("socket", &self.socket)
            .field("debug", &self.debug)
            .field("_process", &self._process)
            .field("model_info", &self.model_info)
            .field("message_id", &self.message_id)
            .field("child", &self.child)
            .field("continuous_state", &self.continuous_state)
            .field("model_parameters", &self.model_parameters)
            .finish()
    }
}
218
219impl EimModel {
    /// Creates a model from the `.eim` file at `path` using a default socket
    /// path in the system temp directory, with debug output disabled.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, EimError> {
        Self::new_with_debug(path, false)
    }
248
    /// Creates a model from `path`, telling the runner to bind the
    /// caller-chosen `socket_path`; debug output disabled.
    pub fn new_with_socket<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        socket_path: S,
    ) -> Result<Self, EimError> {
        Self::new_with_socket_and_debug(path, socket_path, false)
    }
265
266 pub fn new_with_debug<P: AsRef<Path>>(path: P, debug: bool) -> Result<Self, EimError> {
268 let socket_path = std::env::temp_dir().join("eim_socket");
269 Self::new_with_socket_and_debug(path, &socket_path, debug)
270 }
271
272 fn ensure_executable<P: AsRef<Path>>(path: P) -> Result<(), EimError> {
274 use std::os::unix::fs::PermissionsExt;
275
276 let path = path.as_ref();
277 let metadata = std::fs::metadata(path)
278 .map_err(|e| EimError::ExecutionError(format!("Failed to get file metadata: {e}")))?;
279
280 let perms = metadata.permissions();
281 let current_mode = perms.mode();
282 if current_mode & 0o100 == 0 {
283 let mut new_perms = perms;
285 new_perms.set_mode(current_mode | 0o100); std::fs::set_permissions(path, new_perms).map_err(|e| {
287 EimError::ExecutionError(format!("Failed to set executable permissions: {e}"))
288 })?;
289 }
290 Ok(())
291 }
292
293 pub fn new_with_socket_and_debug<P: AsRef<Path>, S: AsRef<Path>>(
295 path: P,
296 socket_path: S,
297 debug: bool,
298 ) -> Result<Self, EimError> {
299 let path = path.as_ref();
300 let socket_path = socket_path.as_ref();
301
302 if path.extension().and_then(|s| s.to_str()) != Some("eim") {
304 return Err(EimError::InvalidPath);
305 }
306
307 let absolute_path = if path.is_absolute() {
309 path.to_path_buf()
310 } else {
311 std::env::current_dir()
312 .map_err(|_e| EimError::InvalidPath)?
313 .join(path)
314 };
315
316 Self::ensure_executable(&absolute_path)?;
318
319 let process = std::process::Command::new(&absolute_path)
321 .arg(socket_path)
322 .spawn()
323 .map_err(|e| EimError::ExecutionError(e.to_string()))?;
324
325 let socket = Self::connect_with_retry(socket_path, Duration::from_secs(5))?;
326
327 let mut model = Self {
328 path: absolute_path, socket_path: socket_path.to_path_buf(),
330 socket,
331 debug,
332 _process: process,
333 model_info: None,
334 message_id: AtomicU32::new(1),
335 child: None,
336 debug_callback: None,
337 continuous_state: None,
338 model_parameters: ModelParameters::default(),
339 };
340
341 model.send_hello()?;
343
344 Ok(model)
345 }
346
347 fn connect_with_retry(socket_path: &Path, timeout: Duration) -> Result<UnixStream, EimError> {
359 let start = Instant::now();
360 let retry_interval = Duration::from_millis(50);
361
362 while start.elapsed() < timeout {
363 match UnixStream::connect(socket_path) {
364 Ok(stream) => return Ok(stream),
365 Err(e) => {
366 if e.kind() != std::io::ErrorKind::NotFound
369 && e.kind() != std::io::ErrorKind::ConnectionRefused
370 {
371 return Err(EimError::SocketError(format!(
372 "Failed to connect to socket: {e}"
373 )));
374 }
375 }
376 }
377 std::thread::sleep(retry_interval);
378 }
379
380 Err(EimError::SocketError(format!(
381 "Timeout waiting for socket {} to become available",
382 socket_path.display()
383 )))
384 }
385
    /// Returns the next outgoing message id, incrementing the shared counter.
    /// Relaxed ordering suffices: ids only need to be unique, not ordered
    /// with respect to other memory operations.
    fn next_message_id(&self) -> u32 {
        self.message_id.fetch_add(1, Ordering::Relaxed)
    }
390
    /// Installs a callback that receives every debug line in addition to
    /// stdout. The callback only fires when the model was created with
    /// `debug = true` (see `debug_message`).
    pub fn set_debug_callback<F>(&mut self, callback: F)
    where
        F: Fn(&str) + Send + Sync + 'static,
    {
        self.debug_callback = Some(Box::new(callback));
    }
406
407 fn debug_message(&self, message: &str) {
409 if self.debug {
410 println!("{message}");
411 if let Some(callback) = &self.debug_callback {
412 callback(message);
413 }
414 }
415 }
416
    /// Performs the initial `hello` handshake and stores the returned
    /// `ModelInfo`. Must succeed before any inference or control message.
    fn send_hello(&mut self) -> Result<(), EimError> {
        let hello_msg = HelloMessage {
            hello: 1,
            id: self.next_message_id(),
        };

        // Protocol is newline-delimited JSON over the Unix socket.
        let msg = serde_json::to_string(&hello_msg)?;
        self.debug_message(&format!("Sending hello message: {msg}"));

        writeln!(self.socket, "{msg}").map_err(|e| {
            self.debug_message(&format!("Failed to send hello: {e}"));
            EimError::SocketError(format!("Failed to send hello message: {e}"))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush hello: {e}"));
            EimError::SocketError(format!("Failed to flush socket: {e}"))
        })?;

        self.debug_message("Waiting for hello response...");

        // NOTE(review): this BufReader is dropped at function exit, so any
        // bytes it buffered past the first line are discarded — assumes the
        // runner sends exactly one line per response; confirm.
        let mut reader = BufReader::new(&self.socket);
        let mut line = String::new();

        match reader.read_line(&mut line) {
            Ok(n) => {
                self.debug_message(&format!("Read {n} bytes: {line}"));

                match serde_json::from_str::<ModelInfo>(&line) {
                    Ok(info) => {
                        self.debug_message("Successfully parsed model info");
                        // A parsed ModelInfo with success == false still
                        // means the runner failed to initialize.
                        if !info.success {
                            self.debug_message("Model initialization failed");
                            return Err(EimError::ExecutionError(
                                "Model initialization failed".to_string(),
                            ));
                        }
                        self.debug_message("Got model info response, storing it");
                        self.model_info = Some(info);
                        return Ok(());
                    }
                    Err(e) => {
                        // Not a ModelInfo payload; try to decode a structured
                        // error response before falling through to the
                        // generic failure at the bottom.
                        self.debug_message(&format!("Failed to parse model info: {e}"));
                        if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line) {
                            if !error.success {
                                self.debug_message(&format!("Got error response: {error:?}"));
                                return Err(EimError::ExecutionError(
                                    error.error.unwrap_or_else(|| "Unknown error".to_string()),
                                ));
                            }
                        }
                    }
                }
            }
            Err(e) => {
                self.debug_message(&format!("Failed to read hello response: {e}"));
                return Err(EimError::SocketError(format!(
                    "Failed to read response: {e}"
                )));
            }
        }

        // Reached when the line parsed as neither ModelInfo nor a failed
        // ErrorResponse.
        self.debug_message("No valid hello response received");
        Err(EimError::SocketError(
            "No valid response received".to_string(),
        ))
    }
484
    /// Returns the absolute path of the `.eim` file backing this model.
    pub fn path(&self) -> &Path {
        &self.path
    }
489
    /// Returns the Unix socket path used to talk to the runner process.
    pub fn socket_path(&self) -> &Path {
        &self.socket_path
    }
494
495 pub fn sensor_type(&self) -> Result<SensorType, EimError> {
497 self.model_info
498 .as_ref()
499 .map(|info| SensorType::from(info.model_parameters.sensor))
500 .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
501 }
502
503 pub fn parameters(&self) -> Result<&ModelParameters, EimError> {
505 self.model_info
506 .as_ref()
507 .map(|info| &info.model_parameters)
508 .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
509 }
510
511 pub fn infer(
535 &mut self,
536 features: Vec<f32>,
537 debug: Option<bool>,
538 ) -> Result<InferenceResponse, EimError> {
539 if self.model_info.is_none() {
541 self.send_hello()?;
542 }
543
544 let uses_continuous_mode = self.requires_continuous_mode();
545
546 if uses_continuous_mode {
547 self.infer_continuous_internal(features, debug)
548 } else {
549 self.infer_single(features, debug)
550 }
551 }
552
553 fn infer_continuous_internal(
554 &mut self,
555 features: Vec<f32>,
556 debug: Option<bool>,
557 ) -> Result<InferenceResponse, EimError> {
558 if self.continuous_state.is_none() {
560 let labels = self
561 .model_info
562 .as_ref()
563 .map(|info| info.model_parameters.labels.clone())
564 .unwrap_or_default();
565 let slice_size = self.input_size()?;
566
567 self.continuous_state = Some(ContinuousState::new(labels, slice_size));
568 }
569
570 let mut state = self.continuous_state.take().unwrap();
572 state.update_features(&features);
573
574 let response = if !state.feature_buffer_full {
575 Ok(InferenceResponse {
577 success: true,
578 id: self.next_message_id(),
579 result: InferenceResult::Classification {
580 classification: HashMap::new(),
581 },
582 })
583 } else {
584 let mut response = self.infer_single(state.feature_matrix.clone(), debug)?;
586
587 if let InferenceResult::Classification {
589 ref mut classification,
590 } = response.result
591 {
592 state.apply_maf(classification);
593 }
594
595 Ok(response)
596 };
597
598 self.continuous_state = Some(state);
600
601 response
602 }
603
    /// Sends one `classify` request and polls (up to 5 s) for a successful
    /// `InferenceResponse` on the socket.
    fn infer_single(
        &mut self,
        features: Vec<f32>,
        debug: Option<bool>,
    ) -> Result<InferenceResponse, EimError> {
        // Lazily complete the hello handshake if it has not happened yet.
        if self.model_info.is_none() {
            self.debug_message("No model info, sending hello message...");
            self.send_hello()?;
            self.debug_message("Hello handshake completed");
        }

        let msg = ClassifyMessage {
            classify: features.clone(),
            id: self.next_message_id(),
            debug,
        };

        let msg_str = serde_json::to_string(&msg)?;
        self.debug_message(&format!(
            "Sending inference message with {} features",
            features.len()
        ));

        writeln!(self.socket, "{msg_str}").map_err(|e| {
            self.debug_message(&format!("Failed to send inference message: {e}"));
            EimError::SocketError(format!("Failed to send inference message: {e}"))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush inference message: {e}"));
            EimError::SocketError(format!("Failed to flush socket: {e}"))
        })?;

        self.debug_message("Inference message sent, waiting for response...");

        // Switch to non-blocking reads so we can enforce our own timeout
        // instead of hanging forever on a silent runner.
        self.socket.set_nonblocking(true).map_err(|e| {
            self.debug_message(&format!("Failed to set non-blocking mode: {e}"));
            EimError::SocketError(format!("Failed to set non-blocking mode: {e}"))
        })?;

        // NOTE(review): this BufReader is dropped on return, so bytes it
        // buffered beyond the consumed line are lost — assumes the runner
        // sends one line per message; confirm.
        let mut reader = BufReader::new(&self.socket);
        let mut buffer = String::new();
        let start = Instant::now();
        let timeout = Duration::from_secs(5);

        while start.elapsed() < timeout {
            match reader.read_line(&mut buffer) {
                Ok(0) => {
                    // Peer closed the connection.
                    self.debug_message("EOF reached");
                    break;
                }
                Ok(n) => {
                    // Suppress the very noisy raw feature dumps from debug output.
                    if !buffer.contains("features:") && !buffer.contains("Features (") {
                        self.debug_message(&format!("Read {n} bytes: {buffer}"));
                    }

                    // The runner may interleave non-JSON debug lines; skip
                    // anything that does not parse as a successful response.
                    if let Ok(response) = serde_json::from_str::<InferenceResponse>(&buffer) {
                        if response.success {
                            self.debug_message("Got successful inference response");
                            // Best effort: restore blocking mode before returning.
                            let _ = self.socket.set_nonblocking(false);
                            return Ok(response);
                        }
                    }
                    buffer.clear();
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    // No data yet; back off briefly and poll again.
                    std::thread::sleep(Duration::from_millis(10));
                    continue;
                }
                Err(e) => {
                    self.debug_message(&format!("Read error: {e}"));
                    let _ = self.socket.set_nonblocking(false);
                    return Err(EimError::SocketError(format!("Read error: {e}")));
                }
            }
        }

        // Timed out (or EOF) without a usable response.
        let _ = self.socket.set_nonblocking(false);
        self.debug_message("Timeout reached");

        Err(EimError::ExecutionError(format!(
            "No valid response received within {} seconds",
            timeout.as_secs()
        )))
    }
696
697 fn requires_continuous_mode(&self) -> bool {
699 self.model_info
700 .as_ref()
701 .map(|info| info.model_parameters.use_continuous_mode)
702 .unwrap_or(false)
703 }
704
    /// Returns the number of input features the model expects per inference
    /// (`input_features_count` from the hello response).
    ///
    /// # Errors
    /// Fails with `ExecutionError` if the handshake has not completed yet.
    pub fn input_size(&self) -> Result<usize, EimError> {
        self.model_info
            .as_ref()
            .map(|info| info.model_parameters.input_features_count as usize)
            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
    }
722
    /// Sends a `set_threshold` request for a learn block and waits for the
    /// runner's acknowledgement.
    ///
    /// NOTE(review): declared `async` but contains no `.await` and performs
    /// blocking socket I/O with no read timeout (unlike `infer_single`), so
    /// it can stall an async executor thread — confirm whether callers rely
    /// on the async signature before changing it.
    pub async fn set_learn_block_threshold(
        &mut self,
        threshold: ThresholdConfig,
    ) -> Result<(), EimError> {
        // Handshake first so the runner is ready for control messages.
        if self.model_info.is_none() {
            self.debug_message("No model info available, sending hello message...");
            self.send_hello()?;
        }

        if let Some(info) = &self.model_info {
            self.debug_message(&format!(
                "Current model type: {}",
                info.model_parameters.model_type
            ));
            self.debug_message(&format!(
                "Current model parameters: {:?}",
                info.model_parameters
            ));
        }

        let msg = SetThresholdMessage {
            set_threshold: threshold,
            id: self.next_message_id(),
        };

        let msg_str = serde_json::to_string(&msg)?;
        self.debug_message(&format!("Sending threshold message: {msg_str}"));

        writeln!(self.socket, "{msg_str}").map_err(|e| {
            self.debug_message(&format!("Failed to send threshold message: {e}"));
            EimError::SocketError(format!("Failed to send threshold message: {e}"))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush threshold message: {e}"));
            EimError::SocketError(format!("Failed to flush socket: {e}"))
        })?;

        let mut reader = BufReader::new(&self.socket);
        let mut line = String::new();

        match reader.read_line(&mut line) {
            Ok(_) => {
                self.debug_message(&format!("Received response: {line}"));
                match serde_json::from_str::<SetThresholdResponse>(&line) {
                    Ok(response) => {
                        if response.success {
                            self.debug_message("Successfully set threshold");
                            Ok(())
                        } else {
                            self.debug_message("Server reported failure setting threshold");
                            Err(EimError::ExecutionError(
                                "Server reported failure setting threshold".to_string(),
                            ))
                        }
                    }
                    Err(e) => {
                        // Fall back to a structured error payload if present.
                        self.debug_message(&format!("Failed to parse threshold response: {e}"));
                        if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line) {
                            Err(EimError::ExecutionError(
                                error.error.unwrap_or_else(|| "Unknown error".to_string()),
                            ))
                        } else {
                            Err(EimError::ExecutionError(format!(
                                "Invalid threshold response format: {e}"
                            )))
                        }
                    }
                }
            }
            Err(e) => {
                self.debug_message(&format!("Failed to read threshold response: {e}"));
                Err(EimError::SocketError(format!(
                    "Failed to read response: {e}"
                )))
            }
        }
    }
818
819 fn get_min_anomaly_score(&self) -> f32 {
821 self.model_info
822 .as_ref()
823 .and_then(|info| {
824 info.model_parameters
825 .thresholds
826 .iter()
827 .find_map(|t| match t {
828 ModelThreshold::AnomalyGMM {
829 min_anomaly_score, ..
830 } => Some(*min_anomaly_score),
831 _ => None,
832 })
833 })
834 .unwrap_or(6.0)
835 }
836
    /// Scales a raw anomaly score relative to the configured minimum
    /// threshold and caps the result at 1.0. Scores at/above the threshold
    /// map to 1.0; there is no lower clamp, so a negative raw score stays
    /// negative.
    fn normalize_anomaly_score(&self, score: f32) -> f32 {
        (score / self.get_min_anomaly_score()).min(1.0)
    }
841
842 pub fn normalize_visual_anomaly(
844 &self,
845 anomaly: f32,
846 max: f32,
847 mean: f32,
848 regions: &[(f32, u32, u32, u32, u32)],
849 ) -> VisualAnomalyResult {
850 let normalized_anomaly = self.normalize_anomaly_score(anomaly);
851 let normalized_max = self.normalize_anomaly_score(max);
852 let normalized_mean = self.normalize_anomaly_score(mean);
853 let normalized_regions: Vec<_> = regions
854 .iter()
855 .map(|(value, x, y, w, h)| (self.normalize_anomaly_score(*value), *x, *y, *w, *h))
856 .collect();
857
858 (
859 normalized_anomaly,
860 normalized_max,
861 normalized_mean,
862 normalized_regions,
863 )
864 }
865}