edge_impulse_runner/ingestion/
mod.rs

1use crate::error::IngestionError;
2use hmac::{Hmac, Mac};
3use mime_guess::from_path;
4use serde::{Deserialize, Serialize};
5use sha2::Sha256;
6use std::path::Path;
7use std::time::{SystemTime, UNIX_EPOCH};
8use tracing::debug;
9const DEFAULT_INGESTION_HOST: &str = "https://ingestion.edgeimpulse.com";
10
11/// Edge Impulse Ingestion API client
12///
13/// This module provides a client implementation for the Edge Impulse Ingestion API, which allows
14/// uploading data samples and files to Edge Impulse for machine learning training, testing, and
15/// anomaly detection.
16///
17/// # API Endpoints
18///
19/// The client supports two types of endpoints:
20///
21/// * Data endpoints (legacy):
22///   - `/api/training/data`
23///   - `/api/testing/data`
24///   - `/api/anomaly/data`
25///
26/// * File endpoints:
27///   - `/api/training/files`
28///   - `/api/testing/files`
29///   - `/api/anomaly/files`
30///
31/// # Examples
32///
33/// ```no_run
34/// use edge_impulse_runner::ingestion::{Ingestion, Category, Sensor, UploadSampleParams};
35///
36/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
37/// // Create a new client
38/// let client = Ingestion::new("your-api-key".to_string())
39///     .with_hmac("optional-hmac-key".to_string());
40///
41/// // Upload a file
42/// let response = client.upload_file(
43///     "data.wav",
44///     Category::Training,
45///     Some("walking".to_string()),
46///     None
47/// ).await?;
48///
49/// // Upload sensor data
50/// let sensors = vec![
51///     Sensor {
52///         name: "accX".to_string(),
53///         units: "m/s2".to_string(),
54///     }
55/// ];
56/// let values = vec![vec![1.0, 2.0, 3.0]];
57///
58/// let params = UploadSampleParams {
59///     device_id: "device-id",
60///     device_type: "CUSTOM_DEVICE",
61///     sensors,
62///     values,
63///     interval_ms: 100.0,
64///     label: Some("walking".to_string()),
65///     category: "training",
66/// };
67///
68/// let response = client.upload_sample(params).await?;
69/// # Ok(())
70/// # }
71/// ```
72
73#[derive(Debug, Serialize, Deserialize)]
74struct Protected {
75    ver: String,
76    alg: String,
77    iat: u64,
78}
79
80#[derive(Debug, Serialize, Deserialize)]
81struct Payload {
82    device_name: String,
83    device_type: String,
84    interval_ms: f64,
85    sensors: Vec<Sensor>,
86    values: Vec<Vec<f64>>,
87}
88
89/// Represents a sensor in the Edge Impulse data format
90///
91/// Each sensor has a name and units that describe the type of data being collected.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct Sensor {
94    pub name: String,
95    pub units: String,
96}
97
98#[derive(Debug, Serialize, Deserialize)]
99struct DataMessage {
100    protected: Protected,
101    signature: String,
102    payload: Payload,
103}
104
/// Data category for Edge Impulse uploads
///
/// Determines which API endpoint will be used for the upload:
/// - Training: Used for gathering training data
/// - Testing: Used for gathering testing data
/// - Anomaly: Used for anomaly detection data
///
/// Derives `PartialEq`/`Eq`/`Hash` so categories can be compared and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Category {
    /// Uploads go to `/api/training/...`.
    Training,
    /// Uploads go to `/api/testing/...`.
    Testing,
    /// Uploads go to `/api/anomaly/...`.
    Anomaly,
}
117
118impl Category {
119    fn as_str(&self) -> &'static str {
120        match self {
121            Category::Training => "training",
122            Category::Testing => "testing",
123            Category::Anomaly => "anomaly",
124        }
125    }
126}
127
128/// Parameters for uploading a sample
129#[derive(Debug)]
130pub struct UploadSampleParams<'a> {
131    /// Device identifier
132    pub device_id: &'a str,
133    /// Type of device
134    pub device_type: &'a str,
135    /// List of sensors
136    pub sensors: Vec<Sensor>,
137    /// Sample values
138    pub values: Vec<Vec<f64>>,
139    /// Interval in milliseconds
140    pub interval_ms: f64,
141    /// Optional label for the sample
142    pub label: Option<String>,
143    /// Category (training, testing, or anomaly)
144    pub category: &'a str,
145}
146
/// Client for the Edge Impulse Ingestion API.
///
/// Holds the project API key, an optional HMAC key, the base URL that requests are sent
/// to, and a flag that turns on verbose stdout output. Create one with `Ingestion::new` or
/// `Ingestion::with_host`, then refine it with the `with_*` builder methods.
pub struct Ingestion {
    /// Project API key, sent in the `x-api-key` header.
    api_key: String,
    /// Key used for HMAC-SHA256 signing, when configured.
    hmac_key: Option<String>,
    /// Base URL that endpoint paths are appended to.
    host: String,
    /// When `true`, request and response details are printed to stdout.
    debug: bool,
}
158
159impl Ingestion {
160    pub fn new(api_key: String) -> Self {
161        Self {
162            api_key,
163            hmac_key: None,
164            host: DEFAULT_INGESTION_HOST.to_string(),
165            debug: false,
166        }
167    }
168
169    pub fn with_host(api_key: String, host: String) -> Self {
170        Self {
171            api_key,
172            hmac_key: None,
173            host,
174            debug: false,
175        }
176    }
177
178    pub fn with_hmac(mut self, hmac_key: String) -> Self {
179        self.hmac_key = Some(hmac_key);
180        self
181    }
182
183    pub fn with_debug(mut self) -> Self {
184        self.debug = true;
185        self
186    }
187
188    async fn create_signature(&self, data: &[u8]) -> Result<String, IngestionError> {
189        if let Some(hmac_key) = &self.hmac_key {
190            let mut mac = Hmac::<Sha256>::new_from_slice(hmac_key.as_bytes())
191                .map_err(|e| IngestionError::Config(e.to_string()))?;
192            mac.update(data);
193            let result = mac.finalize();
194            Ok(hex::encode(result.into_bytes()))
195        } else {
196            Ok("0".repeat(64))
197        }
198    }
199
200    pub async fn upload_sample(
201        &self,
202        params: UploadSampleParams<'_>,
203    ) -> Result<String, IngestionError> {
204        if self.debug {
205            println!("=== Request Details ===");
206            println!("URL: {}/api/{}/data", self.host, params.category);
207            println!("Device ID: {}", params.device_id);
208            println!("Device Type: {}", params.device_type);
209            println!("Sensors: {:?}", params.sensors);
210            println!(
211                "Data size: {} sensors, {} samples",
212                params.sensors.len(),
213                params.values.len()
214            );
215        }
216
217        debug!("Creating data message");
218        let payload = Payload {
219            device_name: params.device_id.to_string(),
220            device_type: params.device_type.to_string(),
221            interval_ms: params.interval_ms,
222            sensors: params.sensors.clone(),
223            values: params.values.iter().map(|v| v.to_vec()).collect(),
224        };
225
226        let message = DataMessage {
227            protected: Protected {
228                ver: "v1".to_string(),
229                alg: "HS256".to_string(),
230                iat: SystemTime::now()
231                    .duration_since(UNIX_EPOCH)
232                    .unwrap()
233                    .as_secs(),
234            },
235            signature: "0".repeat(64),
236            payload,
237        };
238
239        debug!("Serializing data message");
240        let json = serde_json::to_string(&message)?;
241
242        if let Some(ref _hmac_key) = self.hmac_key {
243            debug!("Creating signature for data");
244            let signature = self.create_signature(json.as_bytes()).await?;
245            debug!("Generated signature: {}", signature);
246        }
247
248        debug!("Creating multipart form");
249        let form = reqwest::multipart::Form::new().text("data", json);
250
251        let mut headers = reqwest::header::HeaderMap::new();
252        debug!("Setting up headers");
253        headers.insert("x-api-key", self.api_key.parse()?);
254        headers.insert("x-file-name", format!("{}.json", params.device_id).parse()?);
255
256        if let Some(label) = params.label {
257            debug!("Adding label header: {}", label);
258            headers.insert("x-label", urlencoding::encode(&label).parse()?);
259        }
260
261        if self.debug {
262            println!("=== Request Headers ===");
263            println!("{:#?}", &headers);
264        }
265
266        let client = reqwest::Client::new();
267        let response = client
268            .post(format!("{}/api/{}/data", self.host, params.category))
269            .headers(headers.clone())
270            .multipart(form)
271            .send()
272            .await?;
273
274        let status = response.status();
275
276        if self.debug {
277            println!("=== Response ===");
278            println!("Status: {status}");
279            println!("Headers: {:#?}", response.headers());
280        }
281
282        let body = response.text().await?;
283
284        if self.debug {
285            println!("Body: {body}");
286        }
287
288        if !status.is_success() {
289            return Err(IngestionError::Server {
290                status_code: status.as_u16(),
291                message: body,
292            });
293        }
294
295        Ok(body)
296    }
297
298    /// Upload a file to Edge Impulse using the /files endpoint
299    pub async fn upload_file<P: AsRef<Path>>(
300        &self,
301        file_path: P,
302        category: Category,
303        label: Option<String>,
304        options: Option<UploadOptions>,
305    ) -> Result<String, IngestionError> {
306        let path = file_path.as_ref();
307
308        // Verify the file exists
309        if !path.exists() {
310            return Err(IngestionError::Io(std::io::Error::new(
311                std::io::ErrorKind::NotFound,
312                format!("File not found: {path:?}"),
313            )));
314        }
315
316        // Get the mime type of the file
317        let mime_type = from_path(path).first_or_octet_stream().to_string();
318
319        if self.debug {
320            println!("Detected mime type: {mime_type}");
321        }
322
323        // Read the file
324        let file_data = std::fs::read(path)?;
325
326        // Create the multipart form
327        let form = reqwest::multipart::Form::new().part(
328            "data",
329            reqwest::multipart::Part::bytes(file_data)
330                .file_name(
331                    path.file_name()
332                        .and_then(|n| n.to_str())
333                        .unwrap_or("file")
334                        .to_string(),
335                )
336                .mime_str(&mime_type)?,
337        );
338
339        let mut headers = reqwest::header::HeaderMap::new();
340        headers.insert("x-api-key", self.api_key.parse()?);
341
342        if let Some(label) = label {
343            headers.insert("x-label", urlencoding::encode(&label).parse()?);
344        }
345
346        // Add optional headers from UploadOptions
347        if let Some(opts) = options {
348            if opts.disallow_duplicates {
349                headers.insert("x-disallow-duplicates", "1".parse()?);
350            }
351            if opts.add_date_id {
352                headers.insert("x-add-date-id", "1".parse()?);
353            }
354        }
355
356        if self.debug {
357            println!("=== Request Headers ===");
358            println!("{:#?}", &headers);
359        }
360
361        let client = reqwest::Client::new();
362        let response = client
363            .post(format!("{}/api/{}/files", self.host, category.as_str()))
364            .headers(headers.clone())
365            .multipart(form)
366            .send()
367            .await?;
368
369        let status = response.status();
370
371        if self.debug {
372            println!("=== Response ===");
373            println!("Status: {status}");
374            println!("Headers: {:#?}", response.headers());
375        }
376
377        let body = response.text().await?;
378
379        if self.debug {
380            println!("Body: {body}");
381        }
382
383        if !status.is_success() {
384            return Err(IngestionError::Server {
385                status_code: status.as_u16(),
386                message: body,
387            });
388        }
389
390        Ok(body)
391    }
392}
393
/// Options for file uploads to Edge Impulse
///
/// These options correspond to various headers that can be set when uploading files.
/// `Ingestion::upload_file` sends each enabled flag as a header with the value `1`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct UploadOptions {
    /// When set, the server checks the hash of the message against your current dataset
    /// (sent as `x-disallow-duplicates: 1`).
    pub disallow_duplicates: bool,
    /// Add a date ID to the filename (sent as `x-add-date-id: 1`).
    pub add_date_id: bool,
}
404
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Server;
    use tracing::error;
    use tracing_test::traced_test;

    /// One accelerometer channel used by the upload tests.
    fn create_test_sensors() -> Vec<Sensor> {
        vec![Sensor {
            name: "accelerometer".to_string(),
            units: "m/s2".to_string(),
        }]
    }

    /// Two rows of sample values used by the upload tests.
    fn create_test_values() -> Vec<Vec<f64>> {
        vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]
    }

    #[test]
    #[traced_test]
    fn test_ingestion_creation() {
        let ingestion = Ingestion::new("test_key".to_string());
        assert_eq!(ingestion.api_key, "test_key");
        assert_eq!(ingestion.host, DEFAULT_INGESTION_HOST);
        assert!(ingestion.hmac_key.is_none());

        let ingestion_with_host =
            Ingestion::with_host("test_key".to_string(), "http://custom.host".to_string());
        assert_eq!(ingestion_with_host.host, "http://custom.host");

        let ingestion_with_hmac =
            Ingestion::new("test_key".to_string()).with_hmac("hmac_key".to_string());
        assert!(ingestion_with_hmac.hmac_key.is_some());
        assert_eq!(ingestion_with_hmac.hmac_key.unwrap(), "hmac_key");
    }

    #[test]
    fn test_create_signature() {
        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(async {
            // Well-known HMAC-SHA256 test vector.
            let signed = Ingestion::new("test_key".to_string()).with_hmac("key".to_string());
            let signature = signed
                .create_signature(b"The quick brown fox jumps over the lazy dog")
                .await
                .unwrap();
            assert_eq!(
                signature,
                "f7bc83f430538424b13298e6aa6fb143ef4d59a14946175997479dbc2d1a3cd8"
            );

            // Without an HMAC key the all-zero placeholder is returned.
            let unsigned = Ingestion::new("test_key".to_string());
            let signature = unsigned.create_signature(b"anything").await.unwrap();
            assert_eq!(signature, "0".repeat(64));
        });
    }

    #[test]
    fn test_successful_upload() {
        let mut server = Server::new();

        // `match_header` checks the *request*; `with_header` would only set response headers.
        let mock = server
            .mock("POST", "/api/training/data")
            .match_header("x-api-key", "test_key")
            .match_header("x-file-name", "test_device.json")
            .match_header("x-label", "walking")
            .match_header(
                "content-type",
                mockito::Matcher::Regex("multipart/form-data.*".to_string()),
            )
            .with_status(200)
            .with_body("OK")
            .create();

        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(async {
            let ingestion = Ingestion::with_host("test_key".to_string(), server.url());

            let params = UploadSampleParams {
                device_id: "test_device",
                device_type: "CUSTOM_DEVICE",
                sensors: create_test_sensors(),
                values: create_test_values(),
                interval_ms: 100.0,
                label: Some("walking".to_string()),
                category: "training",
            };

            let result = ingestion.upload_sample(params).await;

            assert!(result.is_ok());
            assert_eq!(result.unwrap(), "OK");
        });

        mock.assert();
    }

    #[test]
    #[traced_test]
    fn test_upload_with_hmac() {
        let mut server = Server::new();
        debug!("Mock server created at: {}", server.url());

        let mock = server
            .mock("POST", "/api/training/data")
            .match_header("x-api-key", "test_key")
            .match_header("x-file-name", "test_device.json")
            .match_header(
                "content-type",
                mockito::Matcher::Regex("multipart/form-data.*".to_string()),
            )
            .with_status(200)
            .with_body("OK")
            .expect(1)
            .create();
        debug!("Mock endpoint created");

        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(async {
            let ingestion = Ingestion::with_host("test_key".to_string(), server.url())
                .with_hmac("test_hmac".to_string());
            debug!("Created ingestion client with HMAC");

            let sensors = create_test_sensors();
            let values = create_test_values();
            debug!(
                "Test data created: sensors={:?}, values={:?}",
                sensors, values
            );

            let params = UploadSampleParams {
                device_id: "test_device",
                device_type: "CUSTOM_DEVICE",
                sensors,
                values,
                interval_ms: 100.0,
                label: None,
                category: "training",
            };

            let result = ingestion.upload_sample(params).await;

            match &result {
                Ok(response) => debug!("Upload successful: {}", response),
                Err(e) => error!("Upload failed: {:?}", e),
            }

            assert!(result.is_ok(), "Upload failed: {:?}", result.err().unwrap());

            mock.assert_async().await;
        });

        debug!("Test completed");
    }

    #[test]
    fn test_upload_error() {
        let mut server = Server::new();

        let mock = server
            .mock("POST", "/api/training/data")
            .with_status(400)
            .with_body("Invalid data")
            .create();

        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(async {
            let ingestion = Ingestion::with_host("test_key".to_string(), server.url());

            let params = UploadSampleParams {
                device_id: "test_device",
                device_type: "CUSTOM_DEVICE",
                sensors: create_test_sensors(),
                values: create_test_values(),
                interval_ms: 100.0,
                label: None,
                category: "training",
            };

            let result = ingestion.upload_sample(params).await;

            assert!(result.is_err());
            match result {
                Err(IngestionError::Server {
                    status_code,
                    message,
                }) => {
                    assert_eq!(status_code, 400);
                    assert_eq!(message, "Invalid data");
                }
                _ => panic!("Expected Server error"),
            }
        });

        mock.assert();
    }

    #[test]
    fn test_invalid_category() {
        let server = Server::new();
        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(async {
            let ingestion = Ingestion::with_host("test_key".to_string(), server.url());

            let params = UploadSampleParams {
                device_id: "test_device",
                device_type: "CUSTOM_DEVICE",
                sensors: create_test_sensors(),
                values: create_test_values(),
                interval_ms: 100.0,
                label: None,
                category: "invalid_category",
            };

            let result = ingestion.upload_sample(params).await;

            assert!(result.is_err());
        });
    }
}