edge_impulse_runner/inference/messages.rs
1//! Message types for Edge Impulse model communication.
2//!
3//! This module defines the message structures used for communication between
4//! the runner and the Edge Impulse model process via Unix sockets. All messages
5//! are serialized to JSON for transmission.
6//!
7//! The communication follows a request-response pattern with the following types:
8//! - Initialization messages (`HelloMessage`)
9//! - Classification requests (`ClassifyMessage`)
10//! - Model information responses (`ModelInfo`)
11//! - Inference results (`InferenceResponse`)
12//! - Error responses (`ErrorResponse`)
//! - Threshold messages (`SetThresholdMessage`) and their acknowledgements (`SetThresholdResponse`)
14
15use crate::types::*;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fmt;
19
20/// Initial handshake message sent to the model process.
21///
22/// This message is sent when establishing communication with the model to
23/// initialize the connection and receive model information.
24#[derive(Serialize, Debug)]
25pub struct HelloMessage {
26 /// Protocol version number
27 pub hello: u32,
28 /// Unique message identifier
29 pub id: u32,
30}
31
32/// Message containing features for classification.
33///
34/// Used to send preprocessed input features to the model for inference.
35#[derive(Serialize, Debug)]
36pub struct ClassifyMessage {
37 /// Vector of preprocessed features matching the model's input requirements
38 pub classify: Vec<f32>,
39 /// Unique message identifier
40 pub id: u32,
41 /// Optional flag to enable debug output for this classification
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub debug: Option<bool>,
44}
45
46/// Response containing model information and parameters.
47///
48/// Received after sending a `HelloMessage`, contains essential information
49/// about the model's configuration and capabilities.
50#[derive(Debug, Deserialize, Clone)]
51pub struct ModelInfo {
52 /// Indicates if the model initialization was successful
53 pub success: bool,
54 /// Message identifier matching the request
55 #[allow(dead_code)]
56 pub id: u32,
57 /// Model parameters including input size, type, and other configuration
58 pub model_parameters: ModelParameters,
59 /// Project information from Edge Impulse
60 #[allow(dead_code)]
61 pub project: ProjectInfo,
62}
63
64/// Represents different types of inference results.
65///
66/// Models can produce different types of outputs depending on their type:
67/// - Classification models return class probabilities
68/// - Object detection models return bounding boxes and optional classifications
69#[derive(Deserialize, Serialize, Debug)]
70#[serde(untagged)]
71pub enum InferenceResult {
72 /// Result from a classification model
73 Classification {
74 /// Map of class names to their probability scores
75 classification: HashMap<String, f32>,
76 },
77 /// Result from an object detection model
78 ObjectDetection {
79 /// Vector of detected objects with their bounding boxes
80 bounding_boxes: Vec<BoundingBox>,
81 /// Optional classification results for the entire image
82 #[serde(default)]
83 classification: HashMap<String, f32>,
84 },
85 /// Result from a visual anomaly detection model
86 VisualAnomaly {
87 /// Grid of anomaly scores for different regions of the image
88 visual_anomaly_grid: Vec<BoundingBox>,
89 /// Maximum anomaly score across all regions
90 visual_anomaly_max: f32,
91 /// Mean anomaly score across all regions
92 visual_anomaly_mean: f32,
93 /// Overall anomaly score for the image
94 anomaly: f32,
95 },
96}
97
98/// Response containing inference results.
99///
100/// Received after sending a `ClassifyMessage`, contains the model's
101/// predictions and confidence scores.
102#[derive(Deserialize, Debug)]
103pub struct InferenceResponse {
104 /// Indicates if the inference was successful
105 pub success: bool,
106 /// Message identifier matching the request
107 pub id: u32,
108 /// The actual inference results
109 pub result: InferenceResult,
110}
111
112impl fmt::Display for InferenceResponse {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 match &self.result {
115 InferenceResult::Classification { classification } => {
116 write!(f, "Classification results: ")?;
117 for (class, probability) in classification {
118 write!(f, "{}={:.2}% ", class, probability * 100.0)?;
119 }
120 Ok(())
121 }
122 InferenceResult::ObjectDetection {
123 bounding_boxes,
124 classification,
125 } => {
126 if !classification.is_empty() {
127 write!(f, "Image classification: ")?;
128 for (class, probability) in classification {
129 write!(f, "{}={:.2}% ", class, probability * 100.0)?;
130 }
131 writeln!(f)?;
132 }
133 write!(f, "Detected objects: ")?;
134 for bbox in bounding_boxes {
135 write!(
136 f,
137 "{}({:.2}%) at ({},{},{},{}) ",
138 bbox.label,
139 bbox.value * 100.0,
140 bbox.x,
141 bbox.y,
142 bbox.width,
143 bbox.height
144 )?;
145 }
146 Ok(())
147 }
148 InferenceResult::VisualAnomaly {
149 visual_anomaly_grid,
150 visual_anomaly_max,
151 visual_anomaly_mean,
152 anomaly,
153 } => {
154 write!(
155 f,
156 "Visual anomaly detection: max={:.2}%, mean={:.2}%, overall={:.2}%",
157 visual_anomaly_max * 100.0,
158 visual_anomaly_mean * 100.0,
159 anomaly * 100.0
160 )?;
161 if !visual_anomaly_grid.is_empty() {
162 writeln!(f)?;
163 write!(f, "Anomaly grid: ")?;
164 for bbox in visual_anomaly_grid {
165 write!(
166 f,
167 "{}({:.2}%) at ({},{},{},{}) ",
168 bbox.label,
169 bbox.value * 100.0,
170 bbox.x,
171 bbox.y,
172 bbox.width,
173 bbox.height
174 )?;
175 }
176 }
177 Ok(())
178 }
179 }
180 }
181}
182
183/// Response indicating an error condition.
184///
185/// Received when an error occurs during model communication or inference.
186#[derive(Deserialize, Debug)]
187pub struct ErrorResponse {
188 /// Always false for error responses
189 pub success: bool,
190 /// Optional error message describing what went wrong
191 #[serde(default)]
192 pub error: Option<String>,
193 /// Message identifier matching the request, if available
194 #[allow(dead_code)]
195 #[serde(default)]
196 pub id: Option<u32>,
197}
198
199/// Message for setting model thresholds
200#[derive(Debug, Serialize)]
201pub struct SetThresholdMessage {
202 /// The threshold configuration to set
203 pub set_threshold: ThresholdConfig,
204 /// Unique message identifier
205 pub id: u32,
206}
207
208/// Different types of threshold configurations that can be set
209#[derive(Debug, Deserialize)]
210pub struct SetThresholdResponse {
211 /// Indicates if the threshold was set successfully
212 pub success: bool,
213 /// Message identifier matching the request
214 pub id: u32,
215}
216
217#[derive(Debug, Serialize)]
218#[serde(tag = "type")]
219pub enum ThresholdConfig {
220 #[serde(rename = "object_detection")]
221 ObjectDetection { id: u32, min_score: f32 },
222 #[serde(rename = "anomaly")]
223 AnomalyGMM { id: u32, min_anomaly_score: f32 },
224}