// edge_impulse_runner/inference/messages.rs
//! Message types for Edge Impulse model communication.
//!
//! This module defines the message structures used for communication between
//! the runner and the Edge Impulse model process via Unix sockets. All messages
//! are serialized to JSON for transmission.
//!
//! The communication follows a request-response pattern with the following types:
//! - Initialization messages (`HelloMessage`)
//! - Classification requests (`ClassifyMessage`)
//! - Model information responses (`ModelInfo`)
//! - Inference results (`InferenceResponse`)
//! - Error responses (`ErrorResponse`)
//! - Threshold configuration requests and responses (`SetThresholdMessage`, `SetThresholdResponse`)
use crate::types::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
/// Handshake request that opens a session with the model process.
///
/// This is the first message sent over a fresh connection; the model answers
/// with a [`ModelInfo`] describing itself.
#[derive(Debug, Serialize)]
pub struct HelloMessage {
    /// Version of the runner protocol being spoken.
    pub hello: u32,
    /// Identifier used to pair this request with its reply.
    pub id: u32,
}
/// Classification request carrying a set of preprocessed input features.
///
/// The model replies with an [`InferenceResponse`].
#[derive(Debug, Serialize)]
pub struct ClassifyMessage {
    /// Feature values, already preprocessed to match the model's input requirements.
    pub classify: Vec<f32>,
    /// Identifier used to pair this request with its reply.
    pub id: u32,
    /// Requests debug output for this classification when `Some(true)`.
    ///
    /// Left out of the serialized JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub debug: Option<bool>,
}
/// Handshake reply describing the loaded model.
///
/// Returned in answer to a [`HelloMessage`]; carries the model's configuration
/// and information about the Edge Impulse project it belongs to.
#[derive(Debug, Deserialize)]
pub struct ModelInfo {
    /// `true` when the model process initialised successfully.
    pub success: bool,
    /// Echo of the `id` sent in the originating request.
    // Kept so the wire format is fully described; `dead_code` is silenced because
    // the crate may never read this field.
    #[allow(dead_code)]
    pub id: u32,
    /// Model configuration such as input size and type.
    pub model_parameters: ModelParameters,
    /// Edge Impulse project the model was built from.
    // Same reason as `id`: deserialized for completeness, possibly unread.
    #[allow(dead_code)]
    pub project: ProjectInfo,
}
/// Represents different types of inference results.
///
/// Models can produce different types of outputs depending on their type:
/// - Visual anomaly models return a grid of scored regions plus aggregate scores
/// - Object detection models return bounding boxes and optional classifications
/// - Classification models return class probabilities
///
/// # Variant order matters
///
/// The enum is `#[serde(untagged)]`: serde tries each variant **in declaration
/// order** and keeps the first one that deserializes, ignoring unknown fields.
/// Variants are therefore listed from most to least specific. If `Classification`
/// came first, any payload containing a `classification` key (for example an
/// object-detection result that also carries image-level classes) would match it
/// and its `bounding_boxes` / anomaly data would be silently discarded.
#[derive(Deserialize, Serialize, Debug)]
#[serde(untagged)]
pub enum InferenceResult {
    /// Result from a visual anomaly detection model
    VisualAnomaly {
        /// Grid of anomaly scores for different regions of the image
        visual_anomaly_grid: Vec<BoundingBox>,
        /// Maximum anomaly score across all regions
        visual_anomaly_max: f32,
        /// Mean anomaly score across all regions
        visual_anomaly_mean: f32,
        /// Overall anomaly score for the image
        anomaly: f32,
    },
    /// Result from an object detection model
    ObjectDetection {
        /// Vector of detected objects with their bounding boxes
        bounding_boxes: Vec<BoundingBox>,
        /// Optional classification results for the entire image (empty when absent)
        #[serde(default)]
        classification: HashMap<String, f32>,
    },
    /// Result from a classification model
    Classification {
        /// Map of class names to their probability scores
        classification: HashMap<String, f32>,
    },
}
/// Reply to a [`ClassifyMessage`] holding the model's predictions.
#[derive(Debug, Deserialize)]
pub struct InferenceResponse {
    /// `true` when inference completed successfully.
    pub success: bool,
    /// Echo of the `id` sent in the originating request.
    pub id: u32,
    /// Model output; its shape depends on the kind of model (see [`InferenceResult`]).
    pub result: InferenceResult,
}
impl fmt::Display for InferenceResponse {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.result {
InferenceResult::Classification { classification } => {
write!(f, "Classification results: ")?;
for (class, probability) in classification {
write!(f, "{}={:.2}% ", class, probability * 100.0)?;
}
Ok(())
}
InferenceResult::ObjectDetection {
bounding_boxes,
classification,
} => {
if !classification.is_empty() {
write!(f, "Image classification: ")?;
for (class, probability) in classification {
write!(f, "{}={:.2}% ", class, probability * 100.0)?;
}
writeln!(f)?;
}
write!(f, "Detected objects: ")?;
for bbox in bounding_boxes {
write!(
f,
"{}({:.2}%) at ({},{},{},{}) ",
bbox.label,
bbox.value * 100.0,
bbox.x,
bbox.y,
bbox.width,
bbox.height
)?;
}
Ok(())
}
InferenceResult::VisualAnomaly {
visual_anomaly_grid,
visual_anomaly_max,
visual_anomaly_mean,
anomaly,
} => {
write!(
f,
"Visual anomaly detection: max={:.2}%, mean={:.2}%, overall={:.2}%",
visual_anomaly_max * 100.0,
visual_anomaly_mean * 100.0,
anomaly * 100.0
)?;
if !visual_anomaly_grid.is_empty() {
writeln!(f)?;
write!(f, "Anomaly grid: ")?;
for bbox in visual_anomaly_grid {
write!(
f,
"{}({:.2}%) at ({},{},{},{}) ",
bbox.label,
bbox.value * 100.0,
bbox.x,
bbox.y,
bbox.width,
bbox.height
)?;
}
}
Ok(())
}
}
}
}
/// Failure reply from the model process.
///
/// Received when something goes wrong during model communication or inference.
#[derive(Debug, Deserialize)]
pub struct ErrorResponse {
    /// Expected to be `false` for every error reply.
    pub success: bool,
    /// Description of the failure, when the model supplied one.
    #[serde(default)]
    pub error: Option<String>,
    /// Echo of the originating request's `id`, when the model included it.
    // `dead_code` is silenced because the crate may never read this field.
    #[allow(dead_code)]
    #[serde(default)]
    pub id: Option<u32>,
}
/// Request that changes one of the model's thresholds.
#[derive(Debug, Serialize)]
pub struct SetThresholdMessage {
    /// Which threshold to change and the value to apply (see [`ThresholdConfig`]).
    pub set_threshold: ThresholdConfig,
    /// Identifier used to pair this request with its reply.
    pub id: u32,
}
/// Reply to a [`SetThresholdMessage`].
#[derive(Debug, Deserialize)]
pub struct SetThresholdResponse {
    /// Indicates if the threshold was set successfully
    pub success: bool,
    /// Message identifier matching the request
    pub id: u32,
}
/// Different types of threshold configurations that can be set via a
/// [`SetThresholdMessage`].
///
/// Serialized as an internally tagged JSON object whose `type` field selects the
/// variant, e.g. `{"type":"object_detection","id":1,"min_score":0.5}`.
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
pub enum ThresholdConfig {
    /// Score threshold for object detection output (`"type": "object_detection"`).
    #[serde(rename = "object_detection")]
    ObjectDetection {
        /// Identifier of the block the threshold applies to — presumably the
        /// model's learn-block id; confirm against the runner protocol.
        id: u32,
        /// Minimum score threshold to apply to detections.
        min_score: f32,
    },
    /// Anomaly score threshold (`"type": "anomaly"`); the variant name suggests a
    /// GMM-based anomaly block.
    #[serde(rename = "anomaly")]
    AnomalyGMM {
        /// Identifier of the block the threshold applies to — presumably the
        /// model's learn-block id; confirm against the runner protocol.
        id: u32,
        /// Minimum anomaly score threshold to apply.
        min_anomaly_score: f32,
    },
}