hyperactor_telemetry/
in_memory_reader.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::sync::Weak;
12use std::time::Duration;
13
14use opentelemetry_sdk::error::OTelSdkResult;
15use opentelemetry_sdk::metrics::InstrumentKind;
16use opentelemetry_sdk::metrics::ManualReader;
17use opentelemetry_sdk::metrics::Pipeline;
18use opentelemetry_sdk::metrics::SdkMeterProvider;
19use opentelemetry_sdk::metrics::Temporality;
20use opentelemetry_sdk::metrics::data::AggregatedMetrics;
21use opentelemetry_sdk::metrics::data::MetricData;
22use opentelemetry_sdk::metrics::data::ResourceMetrics;
23use opentelemetry_sdk::metrics::reader::MetricReader;
24
25// InMemoryReader that uses a shared ManualReader and implements MetricReader
26#[derive(Debug, Clone)]
27pub struct InMemoryReader {
28    manual_reader: Arc<ManualReader>,
29}
30
31impl InMemoryReader {
32    // Create a new InMemoryReader with a specific ManualReader
33    pub fn new(manual_reader: Arc<ManualReader>) -> Self {
34        Self { manual_reader }
35    }
36
37    // Get all counters from the shared ManualReader
38    pub fn get_all_counters(&self) -> HashMap<String, i64> {
39        let mut rm = ResourceMetrics::default();
40        let _ = self.manual_reader.collect(&mut rm);
41
42        // Extract counters directly from the collected metrics
43        let mut counters = HashMap::new();
44        for scope in rm.scope_metrics() {
45            for metric in scope.metrics() {
46                let data = metric.data();
47
48                if let AggregatedMetrics::U64(MetricData::Sum(sum_u64)) = data {
49                    for data_point in sum_u64.data_points() {
50                        let metric_name = metric.name().to_owned();
51                        counters.insert(metric_name, data_point.value() as i64);
52                    }
53                } else if let AggregatedMetrics::I64(MetricData::Sum(sum_i64)) = data {
54                    for data_point in sum_i64.data_points() {
55                        let metric_name = metric.name().to_owned();
56                        counters.insert(metric_name, data_point.value());
57                    }
58                }
59            }
60        }
61        counters
62    }
63}
64
65impl MetricReader for InMemoryReader {
66    fn register_pipeline(&self, pipeline: Weak<Pipeline>) {
67        self.manual_reader.register_pipeline(pipeline);
68    }
69
70    fn collect(&self, rm: &mut ResourceMetrics) -> OTelSdkResult {
71        self.manual_reader.collect(rm)
72    }
73
74    fn force_flush(&self) -> OTelSdkResult {
75        self.manual_reader.force_flush()
76    }
77
78    fn shutdown_with_timeout(&self, timeout: Duration) -> OTelSdkResult {
79        self.manual_reader.shutdown_with_timeout(timeout)
80    }
81
82    fn temporality(&self, kind: InstrumentKind) -> Temporality {
83        self.manual_reader.temporality(kind)
84    }
85
86    fn shutdown(&self) -> OTelSdkResult {
87        self.manual_reader.shutdown()
88    }
89}
90
91// RAII guard for in-memory metrics collection during testing
92//
93// Usage:
94//     let _guard = InMemoryMetrics::new();
95//
96//     // Your code that emits metrics
97//     my_counter.add(42, &[]);
98//
99//     // Check accumulated metrics
100//     let counters = _guard.get_counters();
101//     assert_eq!(counters.get("my_counter"), Some(&42));
102pub struct InMemoryMetrics {
103    in_memory_reader: InMemoryReader,
104    _provider: SdkMeterProvider,
105}
106
107impl InMemoryMetrics {
108    // Create a new InMemoryMetrics
109    //
110    // This will:
111    // 1. Create a ManualReader as shared state
112    // 2. Create an InMemoryReader that uses the shared ManualReader
113    // 3. Create a new SdkMeterProvider with the InMemoryReader
114    // 4. Set it as the global meter provider
115    //
116    // When the guard is dropped, the provider will be shut down.
117    pub fn new() -> Self {
118        // Create the manual reader with cumulative temporality - this state
119        // will only exists for the lifetime of the guard
120        let manual_reader = Arc::new(
121            ManualReader::builder()
122                .with_temporality(Temporality::Cumulative)
123                .build(),
124        );
125
126        // Create the in-memory reader using the shared manual reader
127        let in_memory_reader = InMemoryReader::new(Arc::clone(&manual_reader));
128
129        // Create a new provider with the in-memory reader
130        let provider = SdkMeterProvider::builder()
131            .with_reader(in_memory_reader)
132            .build();
133
134        // Set as global provider
135        opentelemetry::global::set_meter_provider(provider.clone());
136
137        Self {
138            in_memory_reader: InMemoryReader::new(Arc::clone(&manual_reader)),
139            _provider: provider,
140        }
141    }
142
143    // Get all counters accumulated since this guard was created
144    pub fn get_counters(&self) -> HashMap<String, i64> {
145        self.in_memory_reader.get_all_counters()
146    }
147
148    // Get the value of a specific counter by name
149    pub fn get_counter(&self, name: &str) -> Option<i64> {
150        self.get_counters().get(name).copied()
151    }
152
153    // Get a reference to the InMemoryReader for advanced usage
154    pub fn reader(&self) -> &InMemoryReader {
155        &self.in_memory_reader
156    }
157}
158
159impl Drop for InMemoryMetrics {
160    fn drop(&mut self) {
161        // Shutdown our provider
162        let _ = self._provider.shutdown();
163
164        // Reset to a no-op provider to prevent metrics from continuing
165        // to be collected by our in-memory reader after the guard is dropped
166        let noop_provider = SdkMeterProvider::builder().build();
167        opentelemetry::global::set_meter_provider(noop_provider);
168    }
169}
170
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::MutexGuard;
    use std::sync::PoisonError;

    use super::*;

    /// Serializes the tests in this module.
    ///
    /// Every `InMemoryMetrics` replaces the process-wide global meter provider,
    /// and the test harness runs tests on parallel threads. Without this lock,
    /// one test's guard can replace the global while another test is creating
    /// or reading its counters. Static counters presumably bind to whichever
    /// provider is global when first used, so assertions become flaky.
    ///
    /// Tests elsewhere in the crate that touch the global provider are not
    /// covered by this lock.
    static GLOBAL_PROVIDER_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`GLOBAL_PROVIDER_LOCK`]. A lock poisoned by a failed test is
    /// still usable, because the protected data is `()`.
    fn lock_global_provider() -> MutexGuard<'static, ()> {
        GLOBAL_PROVIDER_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn test_in_memory_metrics_guard() {
        // Declared first so it is released only after the guard is dropped.
        let _lock = lock_global_provider();

        // Use the RAII guard
        let guard = InMemoryMetrics::new();

        // Create and use counters
        crate::declare_static_counter!(GUARD_TEST_COUNTER, "guard_test_counter");
        GUARD_TEST_COUNTER.add(42, &[]);

        // Check that we can read the counter value
        let counters = guard.get_counters();
        assert_eq!(counters.get("guard_test_counter"), Some(&42));

        // Test the convenience method
        assert_eq!(guard.get_counter("guard_test_counter"), Some(42));
        assert_eq!(guard.get_counter("nonexistent_counter"), None);

        // Guard will be dropped here, cleaning up automatically
    }

    #[test]
    fn test_multiple_guards_sequential() {
        let _lock = lock_global_provider();

        // Test that multiple guards work correctly when used sequentially
        {
            let guard1 = InMemoryMetrics::new();
            crate::declare_static_counter!(COUNTER_1, "counter_1");
            COUNTER_1.add(10, &[]);
            assert_eq!(guard1.get_counter("counter_1"), Some(10));
        } // guard1 dropped here

        {
            let guard2 = InMemoryMetrics::new();
            crate::declare_static_counter!(COUNTER_2, "counter_2");
            COUNTER_2.add(20, &[]);
            assert_eq!(guard2.get_counter("counter_2"), Some(20));
            // counter_1 should not be visible in guard2 since it's a new provider
            assert_eq!(guard2.get_counter("counter_1"), None);
        } // guard2 dropped here
    }

    #[test]
    fn test_counter_accumulation() {
        let _lock = lock_global_provider();
        let guard = InMemoryMetrics::new();

        crate::declare_static_counter!(ACCUMULATING_COUNTER, "accumulating_counter");

        // Cumulative temporality: each read reports the running total.
        ACCUMULATING_COUNTER.add(1, &[]);
        assert_eq!(guard.get_counter("accumulating_counter"), Some(1));

        ACCUMULATING_COUNTER.add(2, &[]);
        assert_eq!(guard.get_counter("accumulating_counter"), Some(3));

        ACCUMULATING_COUNTER.add(7, &[]);
        assert_eq!(guard.get_counter("accumulating_counter"), Some(10));
    }

    #[test]
    fn test_guard_isolation() {
        // Each guard owns its own ManualReader. Metrics recorded through a
        // later guard's provider must not show up in earlier guards' readers.
        let _lock = lock_global_provider();
        let guard1 = InMemoryMetrics::new();
        let guard2 = InMemoryMetrics::new();

        {
            // Installs a third, temporary provider as the global one; the
            // counter created here records into it.
            let temp_guard1 = InMemoryMetrics::new();
            crate::declare_static_counter!(ISOLATED_COUNTER_1, "isolated_counter_1");
            ISOLATED_COUNTER_1.add(100, &[]);
            assert_eq!(temp_guard1.get_counter("isolated_counter_1"), Some(100));
        }

        {
            // Another temporary provider, independent of the previous one.
            let temp_guard2 = InMemoryMetrics::new();
            crate::declare_static_counter!(ISOLATED_COUNTER_2, "isolated_counter_2");
            ISOLATED_COUNTER_2.add(200, &[]);
            assert_eq!(temp_guard2.get_counter("isolated_counter_2"), Some(200));
            assert_eq!(temp_guard2.get_counter("isolated_counter_1"), None);
        }

        // Neither outer guard observed what the temporary guards recorded.
        for guard in [&guard1, &guard2] {
            assert_eq!(guard.get_counter("isolated_counter_1"), None);
            assert_eq!(guard.get_counter("isolated_counter_2"), None);
        }
    }
}