// edge_impulse_runner/inference/messages.rs

//! Message types for Edge Impulse model communication.
//!
//! This module defines the message structures used for communication between
//! the runner and the Edge Impulse model process via Unix sockets. All messages
//! are serialized to JSON for transmission.
//!
//! Communication follows a request-response pattern using these message types:
//! - Initialization messages (`HelloMessage`)
//! - Classification requests (`ClassifyMessage`)
//! - Model information responses (`ModelInfo`)
//! - Inference results (`InferenceResponse`)
//! - Error responses (`ErrorResponse`)
//! - Threshold messages (`SetThresholdMessage`, `SetThresholdResponse`)

15use crate::types::*;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fmt;
19
20/// Initial handshake message sent to the model process.
21///
22/// This message is sent when establishing communication with the model to
23/// initialize the connection and receive model information.
24#[derive(Serialize, Debug)]
25pub struct HelloMessage {
26    /// Protocol version number
27    pub hello: u32,
28    /// Unique message identifier
29    pub id: u32,
30}
31
32/// Message containing features for classification.
33///
34/// Used to send preprocessed input features to the model for inference.
35#[derive(Serialize, Debug)]
36pub struct ClassifyMessage {
37    /// Vector of preprocessed features matching the model's input requirements
38    pub classify: Vec<f32>,
39    /// Unique message identifier
40    pub id: u32,
41    /// Optional flag to enable debug output for this classification
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub debug: Option<bool>,
44}
45
46/// Response containing model information and parameters.
47///
48/// Received after sending a `HelloMessage`, contains essential information
49/// about the model's configuration and capabilities.
50#[derive(Debug, Deserialize, Clone)]
51pub struct ModelInfo {
52    /// Indicates if the model initialization was successful
53    pub success: bool,
54    /// Message identifier matching the request
55    #[allow(dead_code)]
56    pub id: u32,
57    /// Model parameters including input size, type, and other configuration
58    pub model_parameters: ModelParameters,
59    /// Project information from Edge Impulse
60    #[allow(dead_code)]
61    pub project: ProjectInfo,
62}
63
64/// Represents different types of inference results.
65///
66/// Models can produce different types of outputs depending on their type:
67/// - Classification models return class probabilities
68/// - Object detection models return bounding boxes and optional classifications
69#[derive(Deserialize, Serialize, Debug)]
70#[serde(untagged)]
71pub enum InferenceResult {
72    /// Result from a classification model
73    Classification {
74        /// Map of class names to their probability scores
75        classification: HashMap<String, f32>,
76    },
77    /// Result from an object detection model
78    ObjectDetection {
79        /// Vector of detected objects with their bounding boxes
80        bounding_boxes: Vec<BoundingBox>,
81        /// Optional classification results for the entire image
82        #[serde(default)]
83        classification: HashMap<String, f32>,
84    },
85    /// Result from a visual anomaly detection model
86    VisualAnomaly {
87        /// Grid of anomaly scores for different regions of the image
88        visual_anomaly_grid: Vec<BoundingBox>,
89        /// Maximum anomaly score across all regions
90        visual_anomaly_max: f32,
91        /// Mean anomaly score across all regions
92        visual_anomaly_mean: f32,
93        /// Overall anomaly score for the image
94        anomaly: f32,
95    },
96}
97
98/// Response containing inference results.
99///
100/// Received after sending a `ClassifyMessage`, contains the model's
101/// predictions and confidence scores.
102#[derive(Deserialize, Debug)]
103pub struct InferenceResponse {
104    /// Indicates if the inference was successful
105    pub success: bool,
106    /// Message identifier matching the request
107    pub id: u32,
108    /// The actual inference results
109    pub result: InferenceResult,
110}
111
112impl fmt::Display for InferenceResponse {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        match &self.result {
115            InferenceResult::Classification { classification } => {
116                write!(f, "Classification results: ")?;
117                for (class, probability) in classification {
118                    write!(f, "{}={:.2}% ", class, probability * 100.0)?;
119                }
120                Ok(())
121            }
122            InferenceResult::ObjectDetection {
123                bounding_boxes,
124                classification,
125            } => {
126                if !classification.is_empty() {
127                    write!(f, "Image classification: ")?;
128                    for (class, probability) in classification {
129                        write!(f, "{}={:.2}% ", class, probability * 100.0)?;
130                    }
131                    writeln!(f)?;
132                }
133                write!(f, "Detected objects: ")?;
134                for bbox in bounding_boxes {
135                    write!(
136                        f,
137                        "{}({:.2}%) at ({},{},{},{}) ",
138                        bbox.label,
139                        bbox.value * 100.0,
140                        bbox.x,
141                        bbox.y,
142                        bbox.width,
143                        bbox.height
144                    )?;
145                }
146                Ok(())
147            }
148            InferenceResult::VisualAnomaly {
149                visual_anomaly_grid,
150                visual_anomaly_max,
151                visual_anomaly_mean,
152                anomaly,
153            } => {
154                write!(
155                    f,
156                    "Visual anomaly detection: max={:.2}%, mean={:.2}%, overall={:.2}%",
157                    visual_anomaly_max * 100.0,
158                    visual_anomaly_mean * 100.0,
159                    anomaly * 100.0
160                )?;
161                if !visual_anomaly_grid.is_empty() {
162                    writeln!(f)?;
163                    write!(f, "Anomaly grid: ")?;
164                    for bbox in visual_anomaly_grid {
165                        write!(
166                            f,
167                            "{}({:.2}%) at ({},{},{},{}) ",
168                            bbox.label,
169                            bbox.value * 100.0,
170                            bbox.x,
171                            bbox.y,
172                            bbox.width,
173                            bbox.height
174                        )?;
175                    }
176                }
177                Ok(())
178            }
179        }
180    }
181}
182
183/// Response indicating an error condition.
184///
185/// Received when an error occurs during model communication or inference.
186#[derive(Deserialize, Debug)]
187pub struct ErrorResponse {
188    /// Always false for error responses
189    pub success: bool,
190    /// Optional error message describing what went wrong
191    #[serde(default)]
192    pub error: Option<String>,
193    /// Message identifier matching the request, if available
194    #[allow(dead_code)]
195    #[serde(default)]
196    pub id: Option<u32>,
197}
198
199/// Message for setting model thresholds
200#[derive(Debug, Serialize)]
201pub struct SetThresholdMessage {
202    /// The threshold configuration to set
203    pub set_threshold: ThresholdConfig,
204    /// Unique message identifier
205    pub id: u32,
206}
207
208/// Different types of threshold configurations that can be set
209#[derive(Debug, Deserialize)]
210pub struct SetThresholdResponse {
211    /// Indicates if the threshold was set successfully
212    pub success: bool,
213    /// Message identifier matching the request
214    pub id: u32,
215}
216
217#[derive(Debug, Serialize)]
218#[serde(tag = "type")]
219pub enum ThresholdConfig {
220    #[serde(rename = "object_detection")]
221    ObjectDetection { id: u32, min_score: f32 },
222    #[serde(rename = "anomaly")]
223    AnomalyGMM { id: u32, min_anomaly_score: f32 },
224}