1use rand::{Rng, thread_rng};
2use std::collections::{HashMap, VecDeque};
3use std::fmt;
4use std::io::{BufRead, BufReader, Write};
5use std::os::unix::net::UnixStream;
6use std::path::Path;
7use std::process::Child;
8use std::sync::atomic::{AtomicU32, Ordering};
9use std::time::{Duration, Instant};
10use tempfile::{TempDir, tempdir};
11
12use crate::error::EimError;
13use crate::inference::messages::{
14 ClassifyMessage, ErrorResponse, HelloMessage, InferenceResponse, InferenceResult, ModelInfo,
15 SetThresholdMessage, SetThresholdResponse, ThresholdConfig,
16};
17use crate::types::{ModelParameters, ModelThreshold, SensorType, VisualAnomalyResult};
18
/// Boxed callback that receives each debug message emitted by an [`EimModel`]
/// while its debug mode is enabled.
pub type DebugCallback = Box<dyn Fn(&str) + Send + Sync>;
21
/// Handle to a spawned `.eim` model process, communicated with over a Unix
/// domain socket using newline-delimited JSON messages.
pub struct EimModel {
    /// Absolute path to the `.eim` executable that was spawned.
    path: std::path::PathBuf,
    /// Unix socket path passed to the model process as its argument and
    /// connected to after spawning.
    socket_path: std::path::PathBuf,
    /// Temporary directory holding the socket when one was created by
    /// `new_with_debug`. It is never read; it is held only so the directory
    /// lives as long as the model, hence the `dead_code` allowance.
    #[allow(dead_code)]
    tempdir: Option<TempDir>,
    /// Connected stream to the model process.
    socket: UnixStream,
    /// When true, `debug_message` prints to stdout and forwards to `debug_callback`.
    debug: bool,
    /// Optional sink for debug messages; only invoked while `debug` is true.
    debug_callback: Option<DebugCallback>,
    /// The spawned model process.
    /// NOTE(review): dropping a `Child` does not terminate the process; confirm
    /// a `Drop` impl elsewhere kills it, otherwise it outlives this handle.
    _process: Child,
    /// Model metadata from the hello handshake; `None` until it succeeds.
    model_info: Option<ModelInfo>,
    /// Request id counter; starts at 1 and is incremented each time an id is handed out.
    message_id: AtomicU32,
    /// NOTE(review): always `None` in the constructors visible here; appears unused.
    #[allow(dead_code)]
    child: Option<Child>,
    /// Rolling-window and moving-average state for continuous mode, created lazily.
    continuous_state: Option<ContinuousState>,
    /// Initialized to `ModelParameters::default()`.
    /// NOTE(review): not updated in the code visible here — `parameters()`
    /// reads from `model_info` instead; confirm whether this field is needed.
    model_parameters: ModelParameters,
}
137
138#[derive(Debug)]
139struct ContinuousState {
140 feature_matrix: Vec<f32>,
141 feature_buffer_full: bool,
142 maf_buffers: HashMap<String, MovingAverageFilter>,
143 slice_size: usize,
144}
145
146impl ContinuousState {
147 fn new(labels: Vec<String>, slice_size: usize) -> Self {
148 Self {
149 feature_matrix: Vec::new(),
150 feature_buffer_full: false,
151 maf_buffers: labels
152 .into_iter()
153 .map(|label| (label, MovingAverageFilter::new(4)))
154 .collect(),
155 slice_size,
156 }
157 }
158
159 fn update_features(&mut self, features: &[f32]) {
160 self.feature_matrix.extend_from_slice(features);
162
163 if self.feature_matrix.len() >= self.slice_size {
165 self.feature_buffer_full = true;
166 if self.feature_matrix.len() > self.slice_size {
168 self.feature_matrix
169 .drain(0..self.feature_matrix.len() - self.slice_size);
170 }
171 }
172 }
173
174 fn apply_maf(&mut self, classification: &mut HashMap<String, f32>) {
175 for (label, value) in classification.iter_mut() {
176 if let Some(maf) = self.maf_buffers.get_mut(label) {
177 *value = maf.update(*value);
178 }
179 }
180 }
181}
182
/// Simple moving-average filter over the most recent `window_size` samples.
#[derive(Debug)]
struct MovingAverageFilter {
    /// Most recent samples, oldest at the front; never longer than `window_size`.
    buffer: VecDeque<f32>,
    /// Maximum number of samples averaged; always at least 1.
    window_size: usize,
    /// Sum of the samples currently in `buffer`.
    sum: f32,
}

impl MovingAverageFilter {
    /// Creates an empty filter that averages over up to `window_size` samples.
    ///
    /// A `window_size` of 0 is treated as 1 (no smoothing). Without this,
    /// the first `update` would try to evict from an empty buffer.
    fn new(window_size: usize) -> Self {
        let window_size = window_size.max(1);
        Self {
            buffer: VecDeque::with_capacity(window_size),
            window_size,
            sum: 0.0,
        }
    }

    /// Pushes `value`, evicting the oldest sample once the window is full.
    /// Returns the mean of the samples now in the window.
    fn update(&mut self, value: f32) -> f32 {
        if self.buffer.len() >= self.window_size {
            self.buffer.pop_front();
        }
        self.buffer.push_back(value);
        // Recompute the sum instead of maintaining it incrementally
        // (`+= new; -= old`). The window is tiny, and this avoids unbounded
        // f32 rounding drift over long continuous-mode sessions.
        self.sum = self.buffer.iter().sum();
        self.sum / self.buffer.len() as f32
    }
}
208
209impl fmt::Debug for EimModel {
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 f.debug_struct("EimModel")
212 .field("path", &self.path)
213 .field("socket_path", &self.socket_path)
214 .field("socket", &self.socket)
215 .field("debug", &self.debug)
216 .field("_process", &self._process)
217 .field("model_info", &self.model_info)
218 .field("message_id", &self.message_id)
219 .field("child", &self.child)
220 .field("continuous_state", &self.continuous_state)
222 .field("model_parameters", &self.model_parameters)
223 .finish()
224 }
225}
226
227impl EimModel {
228 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, EimError> {
254 Self::new_with_debug(path, false)
255 }
256
257 pub fn new_with_socket<P: AsRef<Path>, S: AsRef<Path>>(
268 path: P,
269 socket_path: S,
270 ) -> Result<Self, EimError> {
271 Self::new_with_socket_and_debug_internal(path, socket_path, false, None)
272 }
273
274 pub fn new_with_debug<P: AsRef<Path>>(path: P, debug: bool) -> Result<Self, EimError> {
276 let tempdir = tempdir()
278 .map_err(|e| EimError::ExecutionError(format!("Failed to create tempdir: {e}")))?;
279 let mut rng = thread_rng();
281 let socket_name = format!("eim_socket_{}", rng.r#gen::<u64>());
282 let socket_path = tempdir.path().join(socket_name);
283 Self::new_with_socket_and_debug_internal(path, &socket_path, debug, Some(tempdir))
284 }
285
286 fn ensure_executable<P: AsRef<Path>>(path: P) -> Result<(), EimError> {
288 use std::os::unix::fs::PermissionsExt;
289
290 let path = path.as_ref();
291 let metadata = std::fs::metadata(path)
292 .map_err(|e| EimError::ExecutionError(format!("Failed to get file metadata: {e}")))?;
293
294 let perms = metadata.permissions();
295 let current_mode = perms.mode();
296 if current_mode & 0o100 == 0 {
297 let mut new_perms = perms;
299 new_perms.set_mode(current_mode | 0o100); std::fs::set_permissions(path, new_perms).map_err(|e| {
301 EimError::ExecutionError(format!("Failed to set executable permissions: {e}"))
302 })?;
303 }
304 Ok(())
305 }
306
307 pub fn new_with_socket_and_debug<P: AsRef<Path>, S: AsRef<Path>>(
309 path: P,
310 socket_path: S,
311 debug: bool,
312 ) -> Result<Self, EimError> {
313 Self::new_with_socket_and_debug_internal(path, socket_path, debug, None)
314 }
315
316 fn new_with_socket_and_debug_internal<P: AsRef<Path>, S: AsRef<Path>>(
317 path: P,
318 socket_path: S,
319 debug: bool,
320 tempdir: Option<TempDir>,
321 ) -> Result<Self, EimError> {
322 let path = path.as_ref();
323 let socket_path = socket_path.as_ref();
324
325 if path.extension().and_then(|s| s.to_str()) != Some("eim") {
327 return Err(EimError::InvalidPath);
328 }
329
330 let absolute_path = if path.is_absolute() {
332 path.to_path_buf()
333 } else {
334 std::env::current_dir()
335 .map_err(|_e| EimError::InvalidPath)?
336 .join(path)
337 };
338
339 Self::ensure_executable(&absolute_path)?;
341
342 let process = std::process::Command::new(&absolute_path)
344 .arg(socket_path)
345 .spawn()
346 .map_err(|e| EimError::ExecutionError(e.to_string()))?;
347
348 let socket = Self::connect_with_retry(socket_path, Duration::from_secs(5))?;
349
350 let mut model = Self {
351 path: absolute_path, socket_path: socket_path.to_path_buf(),
353 tempdir,
354 socket,
355 debug,
356 _process: process,
357 model_info: None,
358 message_id: AtomicU32::new(1),
359 child: None,
360 debug_callback: None,
361 continuous_state: None,
362 model_parameters: ModelParameters::default(),
363 };
364
365 model.send_hello()?;
367
368 Ok(model)
369 }
370
371 fn connect_with_retry(socket_path: &Path, timeout: Duration) -> Result<UnixStream, EimError> {
383 let start = Instant::now();
384 let retry_interval = Duration::from_millis(50);
385
386 while start.elapsed() < timeout {
387 match UnixStream::connect(socket_path) {
388 Ok(stream) => return Ok(stream),
389 Err(e) => {
390 if e.kind() != std::io::ErrorKind::NotFound
393 && e.kind() != std::io::ErrorKind::ConnectionRefused
394 {
395 return Err(EimError::SocketError(format!(
396 "Failed to connect to socket: {e}"
397 )));
398 }
399 }
400 }
401 std::thread::sleep(retry_interval);
402 }
403
404 Err(EimError::SocketError(format!(
405 "Timeout waiting for socket {} to become available",
406 socket_path.display()
407 )))
408 }
409
410 fn next_message_id(&self) -> u32 {
412 self.message_id.fetch_add(1, Ordering::Relaxed)
413 }
414
415 pub fn set_debug_callback<F>(&mut self, callback: F)
425 where
426 F: Fn(&str) + Send + Sync + 'static,
427 {
428 self.debug_callback = Some(Box::new(callback));
429 }
430
431 fn debug_message(&self, message: &str) {
433 if self.debug {
434 println!("{message}");
435 if let Some(callback) = &self.debug_callback {
436 callback(message);
437 }
438 }
439 }
440
441 fn send_hello(&mut self) -> Result<(), EimError> {
442 let hello_msg = HelloMessage {
443 hello: 1,
444 id: self.next_message_id(),
445 };
446
447 let msg = serde_json::to_string(&hello_msg)?;
448 self.debug_message(&format!("Sending hello message: {msg}"));
449
450 writeln!(self.socket, "{msg}").map_err(|e| {
451 self.debug_message(&format!("Failed to send hello: {e}"));
452 EimError::SocketError(format!("Failed to send hello message: {e}"))
453 })?;
454
455 self.socket.flush().map_err(|e| {
456 self.debug_message(&format!("Failed to flush hello: {e}"));
457 EimError::SocketError(format!("Failed to flush socket: {e}"))
458 })?;
459
460 self.debug_message("Waiting for hello response...");
461
462 let mut reader = BufReader::new(&self.socket);
463 let mut line = String::new();
464
465 match reader.read_line(&mut line) {
466 Ok(n) => {
467 self.debug_message(&format!("Read {n} bytes: {line}"));
468
469 match serde_json::from_str::<ModelInfo>(&line) {
470 Ok(info) => {
471 self.debug_message("Successfully parsed model info");
472 if !info.success {
473 self.debug_message("Model initialization failed");
474 return Err(EimError::ExecutionError(
475 "Model initialization failed".to_string(),
476 ));
477 }
478 self.debug_message("Got model info response, storing it");
479 self.model_info = Some(info);
480 return Ok(());
481 }
482 Err(e) => {
483 self.debug_message(&format!("Failed to parse model info: {e}"));
484 if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line)
485 && !error.success
486 {
487 self.debug_message(&format!("Got error response: {error:?}"));
488 return Err(EimError::ExecutionError(
489 error.error.unwrap_or_else(|| "Unknown error".to_string()),
490 ));
491 }
492 }
493 }
494 }
495 Err(e) => {
496 self.debug_message(&format!("Failed to read hello response: {e}"));
497 return Err(EimError::SocketError(format!(
498 "Failed to read response: {e}"
499 )));
500 }
501 }
502
503 self.debug_message("No valid hello response received");
504 Err(EimError::SocketError(
505 "No valid response received".to_string(),
506 ))
507 }
508
509 pub fn path(&self) -> &Path {
511 &self.path
512 }
513
514 pub fn socket_path(&self) -> &Path {
516 &self.socket_path
517 }
518
519 pub fn sensor_type(&self) -> Result<SensorType, EimError> {
521 self.model_info
522 .as_ref()
523 .map(|info| SensorType::from(info.model_parameters.sensor))
524 .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
525 }
526
527 pub fn parameters(&self) -> Result<&ModelParameters, EimError> {
529 self.model_info
530 .as_ref()
531 .map(|info| &info.model_parameters)
532 .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
533 }
534
535 pub fn infer(
559 &mut self,
560 features: Vec<f32>,
561 debug: Option<bool>,
562 ) -> Result<InferenceResponse, EimError> {
563 if self.model_info.is_none() {
565 self.send_hello()?;
566 }
567
568 let uses_continuous_mode = self.requires_continuous_mode();
569
570 if uses_continuous_mode {
571 self.infer_continuous_internal(features, debug)
572 } else {
573 self.infer_single(features, debug)
574 }
575 }
576
577 fn infer_continuous_internal(
578 &mut self,
579 features: Vec<f32>,
580 debug: Option<bool>,
581 ) -> Result<InferenceResponse, EimError> {
582 if self.continuous_state.is_none() {
584 let labels = self
585 .model_info
586 .as_ref()
587 .map(|info| info.model_parameters.labels.clone())
588 .unwrap_or_default();
589 let slice_size = self.input_size()?;
590
591 self.continuous_state = Some(ContinuousState::new(labels, slice_size));
592 }
593
594 let mut state = self.continuous_state.take().unwrap();
596 state.update_features(&features);
597
598 let response = if !state.feature_buffer_full {
599 Ok(InferenceResponse {
601 success: true,
602 id: self.next_message_id(),
603 result: InferenceResult::Classification {
604 classification: HashMap::new(),
605 },
606 })
607 } else {
608 let mut response = self.infer_single(state.feature_matrix.clone(), debug)?;
610
611 if let InferenceResult::Classification {
613 ref mut classification,
614 } = response.result
615 {
616 state.apply_maf(classification);
617 }
618
619 Ok(response)
620 };
621
622 self.continuous_state = Some(state);
624
625 response
626 }
627
628 fn infer_single(
629 &mut self,
630 features: Vec<f32>,
631 debug: Option<bool>,
632 ) -> Result<InferenceResponse, EimError> {
633 if self.model_info.is_none() {
635 self.debug_message("No model info, sending hello message...");
636 self.send_hello()?;
637 self.debug_message("Hello handshake completed");
638 }
639
640 let msg = ClassifyMessage {
641 classify: features.clone(),
642 id: self.next_message_id(),
643 debug,
644 };
645
646 let msg_str = serde_json::to_string(&msg)?;
647 self.debug_message(&format!(
648 "Sending inference message with {} features",
649 features.len()
650 ));
651
652 writeln!(self.socket, "{msg_str}").map_err(|e| {
653 self.debug_message(&format!("Failed to send inference message: {e}"));
654 EimError::SocketError(format!("Failed to send inference message: {e}"))
655 })?;
656
657 self.socket.flush().map_err(|e| {
658 self.debug_message(&format!("Failed to flush inference message: {e}"));
659 EimError::SocketError(format!("Failed to flush socket: {e}"))
660 })?;
661
662 self.debug_message("Inference message sent, waiting for response...");
663
664 self.socket.set_nonblocking(true).map_err(|e| {
666 self.debug_message(&format!("Failed to set non-blocking mode: {e}"));
667 EimError::SocketError(format!("Failed to set non-blocking mode: {e}"))
668 })?;
669
670 let mut reader = BufReader::new(&self.socket);
671 let mut buffer = String::new();
672 let start = Instant::now();
673 let timeout = Duration::from_secs(5);
674
675 while start.elapsed() < timeout {
676 match reader.read_line(&mut buffer) {
677 Ok(0) => {
678 self.debug_message("EOF reached");
679 break;
680 }
681 Ok(n) => {
682 if !buffer.contains("features:") && !buffer.contains("Features (") {
684 self.debug_message(&format!("Read {n} bytes: {buffer}"));
685 }
686
687 if let Ok(response) = serde_json::from_str::<InferenceResponse>(&buffer)
688 && response.success
689 {
690 self.debug_message("Got successful inference response");
691 let _ = self.socket.set_nonblocking(false);
693 return Ok(response);
694 }
695 buffer.clear();
696 }
697 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
698 std::thread::sleep(Duration::from_millis(10));
700 continue;
701 }
702 Err(e) => {
703 self.debug_message(&format!("Read error: {e}"));
704 let _ = self.socket.set_nonblocking(false);
706 return Err(EimError::SocketError(format!("Read error: {e}")));
707 }
708 }
709 }
710
711 let _ = self.socket.set_nonblocking(false);
713 self.debug_message("Timeout reached");
714
715 Err(EimError::ExecutionError(format!(
716 "No valid response received within {} seconds",
717 timeout.as_secs()
718 )))
719 }
720
721 fn requires_continuous_mode(&self) -> bool {
723 self.model_info
724 .as_ref()
725 .map(|info| info.model_parameters.use_continuous_mode)
726 .unwrap_or(false)
727 }
728
729 pub fn input_size(&self) -> Result<usize, EimError> {
741 self.model_info
742 .as_ref()
743 .map(|info| info.model_parameters.input_features_count as usize)
744 .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
745 }
746
747 pub async fn set_learn_block_threshold(
762 &mut self,
763 threshold: ThresholdConfig,
764 ) -> Result<(), EimError> {
765 if self.model_info.is_none() {
767 self.debug_message("No model info available, sending hello message...");
768 self.send_hello()?;
769 }
770
771 if let Some(info) = &self.model_info {
773 self.debug_message(&format!(
774 "Current model type: {}",
775 info.model_parameters.model_type
776 ));
777 self.debug_message(&format!(
778 "Current model parameters: {:?}",
779 info.model_parameters
780 ));
781 }
782
783 let msg = SetThresholdMessage {
784 set_threshold: threshold,
785 id: self.next_message_id(),
786 };
787
788 let msg_str = serde_json::to_string(&msg)?;
789 self.debug_message(&format!("Sending threshold message: {msg_str}"));
790
791 writeln!(self.socket, "{msg_str}").map_err(|e| {
792 self.debug_message(&format!("Failed to send threshold message: {e}"));
793 EimError::SocketError(format!("Failed to send threshold message: {e}"))
794 })?;
795
796 self.socket.flush().map_err(|e| {
797 self.debug_message(&format!("Failed to flush threshold message: {e}"));
798 EimError::SocketError(format!("Failed to flush socket: {e}"))
799 })?;
800
801 let mut reader = BufReader::new(&self.socket);
802 let mut line = String::new();
803
804 match reader.read_line(&mut line) {
805 Ok(_) => {
806 self.debug_message(&format!("Received response: {line}"));
807 match serde_json::from_str::<SetThresholdResponse>(&line) {
808 Ok(response) => {
809 if response.success {
810 self.debug_message("Successfully set threshold");
811 Ok(())
812 } else {
813 self.debug_message("Server reported failure setting threshold");
814 Err(EimError::ExecutionError(
815 "Server reported failure setting threshold".to_string(),
816 ))
817 }
818 }
819 Err(e) => {
820 self.debug_message(&format!("Failed to parse threshold response: {e}"));
821 if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line) {
823 Err(EimError::ExecutionError(
824 error.error.unwrap_or_else(|| "Unknown error".to_string()),
825 ))
826 } else {
827 Err(EimError::ExecutionError(format!(
828 "Invalid threshold response format: {e}"
829 )))
830 }
831 }
832 }
833 }
834 Err(e) => {
835 self.debug_message(&format!("Failed to read threshold response: {e}"));
836 Err(EimError::SocketError(format!(
837 "Failed to read response: {e}"
838 )))
839 }
840 }
841 }
842
843 fn get_min_anomaly_score(&self) -> f32 {
845 self.model_info
846 .as_ref()
847 .and_then(|info| {
848 info.model_parameters
849 .thresholds
850 .iter()
851 .find_map(|t| match t {
852 ModelThreshold::AnomalyGMM {
853 min_anomaly_score, ..
854 } => Some(*min_anomaly_score),
855 _ => None,
856 })
857 })
858 .unwrap_or(6.0)
859 }
860
861 fn normalize_anomaly_score(&self, score: f32) -> f32 {
863 (score / self.get_min_anomaly_score()).min(1.0)
864 }
865
866 pub fn normalize_visual_anomaly(
868 &self,
869 anomaly: f32,
870 max: f32,
871 mean: f32,
872 regions: &[(f32, u32, u32, u32, u32)],
873 ) -> VisualAnomalyResult {
874 let normalized_anomaly = self.normalize_anomaly_score(anomaly);
875 let normalized_max = self.normalize_anomaly_score(max);
876 let normalized_mean = self.normalize_anomaly_score(mean);
877 let normalized_regions: Vec<_> = regions
878 .iter()
879 .map(|(value, x, y, w, h)| (self.normalize_anomaly_score(*value), *x, *y, *w, *h))
880 .collect();
881
882 (
883 normalized_anomaly,
884 normalized_max,
885 normalized_mean,
886 normalized_regions,
887 )
888 }
889}