evobench_tools/serde_types/
priority.rs

1// There's <https://crates.io/crates/ordered-float> but I haven't
2// reviewed it, except I saw that it (in version 3) doesn't use
3// TryFrom to construct floats, instead ordering NaN at the end. This
4// doesn't seem right for priority use, thus making our own.
5
6use std::{
7    fmt::Display,
8    ops::{Add, Neg},
9    str::FromStr,
10};
11
12use anyhow::anyhow;
13use serde::de::{Error, Visitor};
14
/// A priority level. The level is any orderable `f64` value, i.e. anything
/// but NaN (infinities are allowed). Higher values compare as greater.
///
/// The ordering is total (`Ord`) because the checked constructors reject
/// NaN. Note that `-0.0` and `0.0` compare equal.
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct Priority(f64);

impl Eq for Priority {}

impl PartialOrd for Priority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Delegate to `Ord` so that the two impls can never disagree (a
        // derived `PartialOrd` next to a manual `Ord` is flagged by clippy's
        // `derive_ord_xor_partial_ord`).
        Some(self.cmp(other))
    }
}

impl Ord for Priority {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `f64::partial_cmp` only returns `None` when NaN is involved, which
        // the checked constructors rule out.
        self.0
            .partial_cmp(&other.0)
            .expect("always succeeds due to check in constructor")
    }
}
28
29#[derive(Debug, thiserror::Error)]
30#[error("not a comparable number: {0}")]
31pub struct NonComparableNumber(f64);
32
33impl Priority {
34    pub const HIGH: Priority = Priority::new_unchecked(1.);
35    pub const NORMAL: Priority = Priority::new_unchecked(0.);
36    pub const LOW: Priority = Priority::new_unchecked(-1.);
37
38    /// This does not verify that `value` is comparable. Expect panics
39    /// and other problems if it isn't! This function only exists for
40    /// `const` purposes.
41    pub const fn new_unchecked(value: f64) -> Self {
42        Self(value)
43    }
44
45    pub fn new(value: f64) -> Result<Self, NonComparableNumber> {
46        match value.partial_cmp(&1.23) {
47            Some(_) => Ok(Self(value)),
48            None => Err(NonComparableNumber(value)),
49        }
50    }
51
52    pub fn add(self, difference: f64) -> Result<Self, NonComparableNumber> {
53        Self::new(self.0 + difference)
54    }
55
56    pub fn sub(self, difference: f64) -> Result<Self, NonComparableNumber> {
57        Self::new(self.0 - difference)
58    }
59}
60
61impl Neg for Priority {
62    type Output = Priority;
63
64    fn neg(self) -> Self::Output {
65        Self(-self.0)
66    }
67}
68
69impl Add for Priority {
70    type Output = Result<Priority, NonComparableNumber>;
71
72    fn add(self, rhs: Self) -> Self::Output {
73        Self::new(self.0 + rhs.0)
74    }
75}
76
77impl Default for Priority {
78    fn default() -> Self {
79        Self::new_unchecked(0.)
80    }
81}
82
83impl TryFrom<f64> for Priority {
84    type Error = NonComparableNumber;
85
86    fn try_from(value: f64) -> Result<Self, Self::Error> {
87        Self::new(value)
88    }
89}
90
91impl From<Priority> for f64 {
92    fn from(value: Priority) -> Self {
93        value.0
94    }
95}
96
97impl TryFrom<f32> for Priority {
98    type Error = NonComparableNumber;
99
100    fn try_from(value: f32) -> Result<Self, Self::Error> {
101        Self::new(value.into())
102    }
103}
104
105impl FromStr for Priority {
106    type Err = anyhow::Error;
107
108    fn from_str(s: &str) -> Result<Self, Self::Err> {
109        let s = s.trim();
110        match s {
111            "high" => Ok(Priority::HIGH),
112            "normal" => Ok(Priority::NORMAL),
113            "low" => Ok(Priority::LOW),
114            _ => Ok(Priority::new(s.parse().map_err(|e| {
115                anyhow!("parsing the string {s:?} as Priority: {e:#}")
116            })?)?),
117        }
118    }
119}
120
/// Formats `prefix` followed by `value` rounded to six decimal places, then
/// strips trailing zeros and a then-dangling decimal point. This gets rid of
/// `1.250000000000000001` style output: that value becomes `1.25`, and
/// `10.0` becomes `10`.
fn format_rounded(prefix: &str, value: f64) -> String {
    let mut out = format!("{prefix}{value:.6}");
    while out.ends_with('0') {
        out.pop();
    }
    while out.ends_with('.') {
        out.pop();
    }
    out
}

#[test]
fn t_format_rounded() {
    let cases: &[(f64, &str)] = &[
        (1.250000000001, "1.25"),
        (-1.250000000001, "-1.25"),
        (1.24999999, "1.25"),
        (1.25, "1.25"),
        (1.0, "1"),
        (10.0, "10"),
        (10.00100001, "10.001"),
    ];
    for &(input, expected) in cases {
        assert_eq!(format_rounded("", input), expected, "input: {input}");
    }
}
140
141impl Display for Priority {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        let x = self.0;
144        let prefix = if x.is_sign_negative() { "" } else { " " };
145        let s = format_rounded(prefix, x);
146        f.write_str(&s)
147    }
148}
149
150struct OurVisitor;
151impl<'de> Visitor<'de> for OurVisitor {
152    type Value = Priority;
153
154    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
155        formatter
156            .write_str("a floating point number or one of the strings 'high', 'normal', or 'low")
157    }
158
159    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
160    where
161        E: serde::de::Error,
162    {
163        Priority::from_str(v).map_err(E::custom)
164    }
165
166    // All 3 of these number implementations are necessary; and i16 or
167    // i64 alone didn't work for the number 0!
168
169    fn visit_f64<E: Error>(self, v: f64) -> Result<Self::Value, E> {
170        Priority::new(v).map_err(|e| Error::custom(e))
171    }
172
173    fn visit_u64<E: Error>(self, v: u64) -> Result<Self::Value, E> {
174        Priority::new(v as f64).map_err(|e| Error::custom(e))
175    }
176
177    fn visit_i64<E: Error>(self, v: i64) -> Result<Self::Value, E> {
178        Priority::new(v as f64).map_err(|e| Error::custom(e))
179    }
180}
181
182impl<'de> serde::Deserialize<'de> for Priority {
183    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
184        deserializer.deserialize_any(OurVisitor)
185    }
186}
187
188impl serde::Serialize for Priority {
189    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
190    where
191        S: serde::Serializer,
192    {
193        if *self == Priority::HIGH {
194            serializer.serialize_str("high")
195        } else if *self == Priority::NORMAL {
196            serializer.serialize_str("normal")
197        } else if *self == Priority::LOW {
198            serializer.serialize_str("low")
199        } else {
200            serializer.serialize_f64(self.0)
201        }
202    }
203}