edge_impulse_runner/inference/
mod.rs

1pub mod messages;
2mod model;
3
4pub use model::EimModel;
5
6#[cfg(test)]
7mod tests {
8    use crate::{EimError, EimModel};
9    use std::env;
10    use std::fs::File;
11    use std::io::Write;
12    use std::path::Path;
13    use std::process::Command;
14    use tempfile;
15
16    /// Creates a mock EIM executable for testing
17    ///
18    /// This function creates a shell script that simulates an EIM model by:
19    /// 1. Accepting a socket path argument
20    /// 2. Creating a Unix socket at that path using socat
21    /// 3. Responding to the hello message with a valid JSON response
22    fn create_mock_eim() -> std::path::PathBuf {
23        let manifest_dir =
24            env::var("CARGO_MANIFEST_DIR").expect("Failed to get manifest directory");
25        let mock_path = Path::new(&manifest_dir).join("mock_eim.sh");
26        let response_path = Path::new(&manifest_dir).join("mock_response.json");
27
28        // Create the response JSON file
29        let response_json = r#"{"success":true,"id":1,"model_parameters":{"axis_count":3,"frequency":62.5,"has_anomaly":1,"image_channel_count":0,"image_input_frames":0,"image_input_height":0,"image_input_width":0,"image_resize_mode":"none","inferencing_engine":4,"input_features_count":375,"interval_ms":16,"label_count":6,"labels":["drink","fistbump","idle","snake","updown","wave"],"model_type":"classification","sensor":2,"slice_size":31,"threshold":0.6,"use_continuous_mode":false},"project":{"deploy_version":271,"id":1,"name":"Test Project","owner":"Test Owner"}}"#;
30        std::fs::write(&response_path, response_json).unwrap();
31
32        // Create the mock script that reads from the response file
33        let mock_script = format!(
34            r#"#!/bin/sh
35SOCKET_PATH=$1
36socat UNIX-LISTEN:$SOCKET_PATH,fork SYSTEM:'cat {}'"#,
37            response_path.display()
38        );
39
40        let mut file = File::create(&mock_path).unwrap();
41        file.write_all(mock_script.as_bytes()).unwrap();
42
43        // Make the script executable
44        use std::os::unix::fs::PermissionsExt;
45        let mut perms = std::fs::metadata(&mock_path).unwrap().permissions();
46        perms.set_mode(0o755);
47        std::fs::set_permissions(&mock_path, perms).unwrap();
48
49        mock_path
50    }
51
52    #[test]
53    fn test_missing_file_error() {
54        // Create a temporary directory for the socket
55        let temp_dir = tempfile::tempdir().unwrap();
56        let socket_path = temp_dir.path().join("test.socket");
57
58        // Test with a non-existent file
59        let result = EimModel::new_with_socket("unknown.eim", &socket_path);
60        match result {
61            Err(EimError::ExecutionError(msg)) if msg.contains("No such file") => (),
62            other => panic!("Expected ExecutionError for missing file, got {:?}", other),
63        }
64    }
65
66    #[test]
67    fn test_invalid_extension() {
68        // Verify that attempting to load a file without .eim extension returns InvalidPath
69        let temp_file = std::env::temp_dir().join("test.txt");
70        std::fs::write(&temp_file, "dummy content").unwrap();
71
72        let result = EimModel::new(&temp_file);
73        match result {
74            Err(EimError::InvalidPath) => (),
75            _ => panic!("Expected InvalidPath when file has wrong extension"),
76        }
77    }
78
79    #[test]
80    fn test_successful_connection() {
81        // Check if socat is available (required for this test)
82        let socat_check = Command::new("which")
83            .arg("socat")
84            .output()
85            .expect("Failed to check for socat");
86
87        if !socat_check.status.success() {
88            println!("Skipping test: socat is not installed");
89            return;
90        }
91
92        // Create a temporary directory for the socket
93        let temp_dir = tempfile::tempdir().unwrap();
94        let socket_path = temp_dir.path().join("test.socket");
95
96        // Create and set up the mock EIM executable
97        let mock_path = create_mock_eim();
98        let response_path = mock_path.with_extension("json");
99        let mut mock_path_with_eim = mock_path.clone();
100        mock_path_with_eim.set_extension("eim");
101        std::fs::rename(&mock_path, &mock_path_with_eim).unwrap();
102
103        // Test the connection with the custom socket path
104        let result = EimModel::new_with_socket(&mock_path_with_eim, &socket_path);
105        assert!(
106            result.is_ok(),
107            "Failed to create EIM model: {:?}",
108            result.err()
109        );
110
111        // Clean up the test files
112        if mock_path_with_eim.exists() {
113            std::fs::remove_file(&mock_path_with_eim).unwrap_or_else(|e| {
114                println!("Warning: Failed to remove mock EIM file: {}", e);
115            });
116        }
117        if response_path.exists() {
118            std::fs::remove_file(&response_path).unwrap_or_else(|e| {
119                println!("Warning: Failed to remove response file: {}", e);
120            });
121        }
122    }
123
124    #[test]
125    fn test_connection_timeout() {
126        // Create a temporary directory
127        let temp_dir = tempfile::tempdir().unwrap();
128        let socket_path = temp_dir.path().join("test.socket");
129        let model_path = temp_dir.path().join("dummy.eim");
130
131        // Create the executable
132        let script = "#!/bin/sh\nsleep 10\n"; // Sleep long enough for timeout
133        std::fs::write(&model_path, script).unwrap();
134
135        #[cfg(unix)]
136        {
137            use std::os::unix::fs::PermissionsExt;
138            let mut perms = std::fs::metadata(&model_path).unwrap().permissions();
139            perms.set_mode(0o755);
140            std::fs::set_permissions(&model_path, perms).unwrap();
141        }
142
143        // Test that we get the expected timeout error
144        let result = EimModel::new_with_socket(&model_path, &socket_path);
145        assert!(
146            matches!(result,
147                Err(EimError::SocketError(ref msg)) if msg.contains("Timeout waiting for socket")
148            ),
149            "Expected timeout error, got {:?}",
150            result
151        );
152    }
153}