noisy_float/
lib.rs

1// Copyright 2016-2021 Matthew D. Michelotti
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! This crate contains floating point types that panic if they are set
16//! to an illegal value, such as NaN.
17//!
18//! The name "Noisy Float" comes from
19//! the terms "quiet NaN" and "signaling NaN"; "signaling" was too long
20//! to put in a struct/crate name, so "noisy" is used instead, being the opposite
21//! of "quiet."
22//!
23//! The standard types defined in `noisy_float::types` follow the principle
24//! demonstrated by Rust's handling of integer overflow:
25//! a bad arithmetic operation is considered an error,
26//! but it is too costly to check everywhere in optimized builds.
27//! For each floating point number that is created, a `debug_assert!` invocation is used
28//! to check if it is valid or not.
29//! This way, there are guarantees when developing code that floating point
30//! numbers have valid values,
31//! but during a release run there is *no overhead* for using these floating
32//! point types compared to using `f32` or `f64` directly.
33//!
34//! This crate makes use of the num, bounded, signed and floating point traits
35//! in the popular `num_traits` crate.
36//! This crate can be compiled with no_std.
37//!
38//! # Examples
39//! An example using the `R64` type, which corresponds to *finite* `f64` values.
40//!
41//! ```
42//! use noisy_float::prelude::*;
43//!
44//! fn geometric_mean(a: R64, b: R64) -> R64 {
45//!     (a * b).sqrt() //used just like regular floating point numbers
46//! }
47//!
48//! fn mean(a: R64, b: R64) -> R64 {
49//!     (a + b) * 0.5 //the RHS of ops can be the underlying float type
50//! }
51//!
52//! println!(
53//!     "geometric_mean(10.0, 20.0) = {}",
54//!     geometric_mean(r64(10.0), r64(20.0))
55//! );
56//! //prints 14.142...
57//! assert!(mean(r64(10.0), r64(20.0)) == 15.0);
58//! ```
59//!
60//! An example using the `N32` type, which corresponds to *non-NaN* `f32` values.
61//! The float types in this crate are able to implement `Eq` and `Ord` properly,
62//! since NaN is not allowed.
63//!
64//! ```
65//! use noisy_float::prelude::*;
66//!
67//! let values = vec![n32(3.0), n32(-1.5), n32(71.3), N32::infinity()];
68//! assert!(values.iter().cloned().min() == Some(n32(-1.5)));
69//! assert!(values.iter().cloned().max() == Some(N32::infinity()));
70//! ```
71//!
72//! An example converting from R64 to primitive types.
73//!
74//! ```
75//! use noisy_float::prelude::*;
76//! use num_traits::cast::ToPrimitive;
77//!
78//! let value_r64: R64 = r64(1.0);
79//! let value_f64_a: f64 = value_r64.into();
80//! let value_f64_b: f64 = value_r64.raw();
81//! let value_u64: u64 = value_r64.to_u64().unwrap();
82//!
83//! assert!(value_f64_a == value_f64_b);
84//! assert!(value_f64_a as u64 == value_u64);
85//! ```
86//!
87//! # Features
88//!
89//! This crate has the following cargo features:
90//!
//! - `serde`: Enables serialization and deserialization for all `NoisyFloat`s
//!   using serde 1.0. Values are transparently (de)serialized as plain floats;
//!   deserializing a value that the checker rejects yields an error.
93//! - `approx`: Adds implementations to use `NoisyFloat` with the `approx`
94//!   crate
95
96#![no_std]
97
98#[cfg(feature = "serde")]
99use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error};
100
101pub mod checkers;
102mod float_impl;
103pub mod types;
104
105/// Prelude for the `noisy_float` crate.
106///
107/// This includes all of the types defined in the `noisy_float::types` module,
108/// as well as a re-export of the `Float` trait from the `num_traits` crate.
109/// It is important to have this re-export here, because it allows the user
110/// to access common floating point methods like `abs()`, `sqrt()`, etc.
111pub mod prelude {
112    pub use crate::types::*;
113
114    #[doc(no_inline)]
115    pub use num_traits::Float;
116}
117
118use core::{fmt, marker::PhantomData};
119use num_traits::Float;
120
/// Trait for checking whether a floating point number is *valid*.
///
/// Each implementation defines its own criteria for what constitutes a *valid* value.
/// Both methods are associated functions (they take no `self`), so the checker
/// type is used purely as a type-level marker, e.g. as the `C` parameter of
/// `NoisyFloat<F, C>`.
pub trait FloatChecker<F> {
    /// Returns `true` if (and only if) the given floating point number is *valid*
    /// according to this checker's criteria.
    ///
    /// The only hard requirement is that NaN *must* be considered *invalid*
    /// for all implementations of `FloatChecker`.
    fn check(value: F) -> bool;

    /// A function that may panic if the floating point number is *invalid*.
    ///
    /// Implementations should either call `assert!(Self::check(value), ...)`
    /// (checked in every build) or `debug_assert!(Self::check(value), ...)`
    /// (checked only when debug assertions are enabled).
    fn assert(value: F);
}
137
/// A floating point number with a restricted set of legal values.
///
/// Typical users will not need to access this struct directly, but
/// can instead use the type aliases found in the module `noisy_float::types`.
/// However, this struct together with a `FloatChecker` implementation can be used
/// to define custom behavior.
///
/// The underlying float type is `F`, usually `f32` or `f64`.
/// Valid values for the float are determined by the float checker `C`.
/// If an invalid value would ever be returned from a method on this type,
/// the method will panic instead, using either `assert!` or `debug_assert!`
/// as defined by the float checker.
/// The exception to this rule is for methods that return an `Option` containing
/// a `NoisyFloat`, in which case the result would be `None` if the value is invalid.
///
/// # Layout
///
/// The struct is `#[repr(transparent)]`: its only non-zero-sized field is the
/// wrapped `F` (the other field is a zero-sized `PhantomData`), so it has the
/// same size, alignment and ABI as `F`. The in-place reference conversions
/// (`borrowed`, `borrowed_mut` and their variants) depend on this guarantee to
/// reinterpret `&F` / `&mut F` as `&NoisyFloat<F, C>` / `&mut NoisyFloat<F, C>`.
#[repr(transparent)]
pub struct NoisyFloat<F: Float, C: FloatChecker<F>> {
    /// The wrapped float; `raw()` hands it back unchanged.
    value: F,
    /// Zero-sized marker that ties the checker type `C` to this value;
    /// no instance of `C` is stored.
    checker: PhantomData<C>,
}
157
158impl<F: Float, C: FloatChecker<F>> NoisyFloat<F, C> {
159    /// Constructs a `NoisyFloat` with the given value.
160    ///
161    /// Uses the `FloatChecker` to assert that the value is valid.
162    #[inline]
163    pub fn new(value: F) -> Self {
164        C::assert(value);
165        Self::unchecked_new_generic(value)
166    }
167
168    #[inline]
169    fn unchecked_new_generic(value: F) -> Self {
170        NoisyFloat {
171            value,
172            checker: PhantomData,
173        }
174    }
175
176    /// Tries to construct a `NoisyFloat` with the given value.
177    ///
178    /// Returns `None` if the value is invalid.
179    #[inline]
180    pub fn try_new(value: F) -> Option<Self> {
181        if C::check(value) {
182            Some(NoisyFloat {
183                value,
184                checker: PhantomData,
185            })
186        } else {
187            None
188        }
189    }
190
191    /// Converts the value in-place to a reference to a `NoisyFloat`.
192    ///
193    /// Uses the `FloatChecker` to assert that the value is valid.
194    #[inline]
195    pub fn borrowed(value: &F) -> &Self {
196        C::assert(*value);
197        Self::unchecked_borrowed(value)
198    }
199
200    #[inline]
201    fn unchecked_borrowed(value: &F) -> &Self {
202        // This is safe because `NoisyFloat` is a thin wrapper around the
203        // floating-point type.
204        unsafe { &*(value as *const F as *const Self) }
205    }
206
207    /// Tries to convert the value in-place to a reference to a `NoisyFloat`.
208    ///
209    /// Returns `None` if the value is invalid.
210    #[inline]
211    pub fn try_borrowed(value: &F) -> Option<&Self> {
212        if C::check(*value) {
213            Some(Self::unchecked_borrowed(value))
214        } else {
215            None
216        }
217    }
218
219    /// Converts the value in-place to a mutable reference to a `NoisyFloat`.
220    ///
221    /// Uses the `FloatChecker` to assert that the value is valid.
222    #[inline]
223    pub fn borrowed_mut(value: &mut F) -> &mut Self {
224        C::assert(*value);
225        Self::unchecked_borrowed_mut(value)
226    }
227
228    #[inline]
229    fn unchecked_borrowed_mut(value: &mut F) -> &mut Self {
230        // This is safe because `NoisyFloat` is a thin wrapper around the
231        // floating-point type.
232        unsafe { &mut *(value as *mut F as *mut Self) }
233    }
234
235    /// Tries to convert the value in-place to a mutable reference to a `NoisyFloat`.
236    ///
237    /// Returns `None` if the value is invalid.
238    #[inline]
239    pub fn try_borrowed_mut(value: &mut F) -> Option<&mut Self> {
240        if C::check(*value) {
241            Some(Self::unchecked_borrowed_mut(value))
242        } else {
243            None
244        }
245    }
246
247    /// Constructs a `NoisyFloat` with the given `f32` value.
248    ///
249    /// May panic not only by the `FloatChecker` but also
250    /// by unwrapping the result of a `NumCast` invocation for type `F`,
251    /// although the later should not occur in normal situations.
252    #[inline]
253    pub fn from_f32(value: f32) -> Self {
254        Self::new(F::from(value).unwrap())
255    }
256
257    /// Constructs a `NoisyFloat` with the given `f64` value.
258    ///
259    /// May panic not only by the `FloatChecker` but also
260    /// by unwrapping the result of a `NumCast` invocation for type `F`,
261    /// although the later should not occur in normal situations.
262    #[inline]
263    pub fn from_f64(value: f64) -> Self {
264        Self::new(F::from(value).unwrap())
265    }
266
267    /// Returns the underlying float value.
268    #[inline]
269    pub fn raw(self) -> F {
270        self.value
271    }
272
273    /// Compares and returns the minimum of two values.
274    ///
275    /// This method exists to disambiguate between `num_traits::Float.min` and `std::cmp::Ord.min`.
276    #[inline]
277    pub fn min(self, other: Self) -> Self {
278        Ord::min(self, other)
279    }
280
281    /// Compares and returns the maximum of two values.
282    ///
283    /// This method exists to disambiguate between `num_traits::Float.max` and `std::cmp::Ord.max`.
284    #[inline]
285    pub fn max(self, other: Self) -> Self {
286        Ord::max(self, other)
287    }
288}
289
290impl<F: Float + Default, C: FloatChecker<F>> Default for NoisyFloat<F, C> {
291    #[inline]
292    fn default() -> Self {
293        Self::new(F::default())
294    }
295}
296
297impl<F: Float + fmt::Debug, C: FloatChecker<F>> fmt::Debug for NoisyFloat<F, C> {
298    #[inline]
299    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
300        fmt::Debug::fmt(&self.value, f)
301    }
302}
303
304impl<F: Float + fmt::Display, C: FloatChecker<F>> fmt::Display for NoisyFloat<F, C> {
305    #[inline]
306    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
307        fmt::Display::fmt(&self.value, f)
308    }
309}
310
311impl<F: Float + fmt::LowerExp, C: FloatChecker<F>> fmt::LowerExp for NoisyFloat<F, C> {
312    #[inline]
313    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
314        fmt::LowerExp::fmt(&self.value, f)
315    }
316}
317
318impl<F: Float + fmt::UpperExp, C: FloatChecker<F>> fmt::UpperExp for NoisyFloat<F, C> {
319    #[inline]
320    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
321        fmt::UpperExp::fmt(&self.value, f)
322    }
323}
324
#[cfg(feature = "serde")]
impl<F: Float + Serialize, C: FloatChecker<F>> Serialize for NoisyFloat<F, C> {
    /// Serializes as the bare inner float, with no wrapping structure.
    fn serialize<S: Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
        let inner: &F = &self.value;
        Serialize::serialize(inner, ser)
    }
}
331
#[cfg(feature = "serde")]
impl<'de, F: Float + Deserialize<'de>, C: FloatChecker<F>> Deserialize<'de> for NoisyFloat<F, C> {
    /// Deserializes a bare float, then validates it with `C::check`.
    ///
    /// A value the checker rejects becomes a custom deserialization error
    /// ("invalid NoisyFloat") rather than a panic.
    fn deserialize<D: Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        match Self::try_new(F::deserialize(de)?) {
            Some(checked) => Ok(checked),
            None => Err(D::Error::custom("invalid NoisyFloat")),
        }
    }
}
339
#[cfg(test)]
mod tests {
    // The crate itself is `no_std`, but the tests use `std` items such as `Vec`.
    extern crate std;
    use std::prelude::v1::*;

    use crate::prelude::*;
    #[cfg(feature = "serde")]
    use serde_derive::{Deserialize, Serialize};
    #[cfg(feature = "serde")]
    use serde_json;
    // `f32::INFINITY`, `f64::NAN`, etc. below are the primitive types' associated
    // constants; only `std::f64::consts` is imported, not the legacy modules.
    use std::{
        f64::consts,
        hash::{Hash, Hasher},
        mem::{align_of, size_of},
    };

    #[test]
    fn smoke_test() {
        assert_eq!(n64(1.0) + 2.0, 3.0);
        assert_ne!(n64(3.0), n64(2.9));
        assert!(r64(1.0) < 2.0);
        let mut value = n64(18.0);
        value %= n64(5.0);
        assert_eq!(-value, n64(-3.0));
        assert_eq!(r64(1.0).exp(), consts::E);
        assert_eq!(N64::try_new(1.0).unwrap() / N64::infinity(), 0.0);
        assert_eq!(N64::from_f32(f32::INFINITY), N64::from_f64(f64::INFINITY));
        assert_eq!(R64::try_new(f64::NEG_INFINITY), None);
        assert_eq!(N64::try_new(f64::NAN), None);
        assert_eq!(R64::try_new(f64::NAN), None);
        assert_eq!(N64::try_borrowed(&f64::NAN), None);
        let mut nan = f64::NAN;
        assert_eq!(N64::try_borrowed_mut(&mut nan), None);
    }

    #[test]
    fn ensure_layout() {
        assert_eq!(size_of::<N32>(), size_of::<f32>());
        assert_eq!(align_of::<N32>(), align_of::<f32>());

        assert_eq!(size_of::<N64>(), size_of::<f64>());
        assert_eq!(align_of::<N64>(), align_of::<f64>());
    }

    // Literal values are chosen to avoid Clippy's deny-by-default
    // `approx_constant` lint (e.g. `3.14` is flagged as an approximation of PI).
    #[test]
    fn borrowed_casts() {
        assert_eq!(R64::borrowed(&3.25), &3.25);
        assert_eq!(N64::borrowed(&[f64::INFINITY; 2][0]), &f64::INFINITY);
        assert_eq!(N64::borrowed_mut(&mut 2.72), &mut 2.72);
    }

    #[test]
    fn try_borrowed_accepts_valid_values() {
        assert_eq!(R64::try_borrowed(&2.5), Some(&r64(2.5)));
        let mut value = 2.5;
        assert_eq!(N64::try_borrowed_mut(&mut value), Some(&mut n64(2.5)));
    }

    #[test]
    fn default_is_zero() {
        assert_eq!(N64::default(), 0.0);
        assert_eq!(R32::default(), 0.0);
    }

    #[test]
    fn test_convert() {
        assert_eq!(f32::from(r32(3.0)), 3.0f32);
        assert_eq!(f64::from(r32(5.0)), 5.0f64);
        assert_eq!(f64::from(r64(7.0)), 7.0f64);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic]
    fn n64_nan() {
        let _ = n64(0.0) / n64(0.0);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic]
    fn r64_nan() {
        let _ = r64(0.0) / r64(0.0);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic]
    fn r64_infinity() {
        let _ = r64(1.0) / r64(0.0);
    }

    #[test]
    fn resolves_min_max() {
        assert_eq!(r64(1.0).min(r64(3.0)), r64(1.0));
        assert_eq!(r64(1.0).max(r64(3.0)), r64(3.0));
    }

    #[test]
    fn epsilon() {
        assert_eq!(R32::epsilon(), f32::EPSILON);
        assert_eq!(R64::epsilon(), f64::EPSILON);
    }

    #[test]
    fn test_try_into() {
        use std::convert::{TryFrom, TryInto};
        let _: R64 = 1.0.try_into().unwrap();
        let _ = R64::try_from(f64::INFINITY).unwrap_err();
    }

    /// Hasher that records the raw bytes fed to it, so tests can compare the
    /// exact input two values contribute to a hash.
    struct TestHasher {
        bytes: Vec<u8>,
    }

    impl Hasher for TestHasher {
        fn finish(&self) -> u64 {
            panic!("unexpected Hasher.finish invocation")
        }
        fn write(&mut self, bytes: &[u8]) {
            self.bytes.extend_from_slice(bytes)
        }
    }

    /// Returns the bytes `value.hash` writes into a fresh `TestHasher`.
    fn hash_bytes<T: Hash>(value: T) -> Vec<u8> {
        let mut hasher = TestHasher { bytes: Vec::new() };
        value.hash(&mut hasher);
        hasher.bytes
    }

    #[test]
    fn test_hash() {
        assert_eq!(hash_bytes(r64(10.3)), hash_bytes(10.3f64.to_bits()));
        assert_ne!(hash_bytes(r64(10.3)), hash_bytes(10.4f64.to_bits()));
        assert_eq!(hash_bytes(r32(10.3)), hash_bytes(10.3f32.to_bits()));
        assert_ne!(hash_bytes(r32(10.3)), hash_bytes(10.4f32.to_bits()));

        assert_eq!(
            hash_bytes(N64::infinity()),
            hash_bytes(f64::INFINITY.to_bits())
        );
        assert_eq!(
            hash_bytes(N64::neg_infinity()),
            hash_bytes(f64::NEG_INFINITY.to_bits())
        );

        // positive and negative zero should have the same hashes
        assert_eq!(hash_bytes(r64(0.0)), hash_bytes(0.0f64.to_bits()));
        assert_eq!(hash_bytes(r64(-0.0)), hash_bytes(0.0f64.to_bits()));
        assert_eq!(hash_bytes(r32(0.0)), hash_bytes(0.0f32.to_bits()));
        assert_eq!(hash_bytes(r32(-0.0)), hash_bytes(0.0f32.to_bits()));
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serialize_transparently_as_float() {
        let num = R32::new(3.5);
        let should_be = "3.5";

        let got = serde_json::to_string(&num).unwrap();
        assert_eq!(got, should_be);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn deserialize_transparently_as_float() {
        let src = "3.5";
        let should_be = R32::new(3.5);

        let got: R32 = serde_json::from_str(src).unwrap();
        assert_eq!(got, should_be);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn deserialize_invalid_float() {
        use crate::{FloatChecker, NoisyFloat};
        struct PositiveChecker;
        impl FloatChecker<f64> for PositiveChecker {
            fn check(value: f64) -> bool {
                value > 0.
            }
            fn assert(value: f64) {
                debug_assert!(Self::check(value))
            }
        }

        let src = "-1.0";
        let got: Result<NoisyFloat<f64, PositiveChecker>, _> = serde_json::from_str(src);
        assert!(got.is_err());
    }

    // Make sure you can use serde_derive with noisy floats.
    #[cfg(feature = "serde")]
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Dummy {
        value: N64,
    }

    #[cfg(feature = "serde")]
    #[test]
    fn deserialize_struct_containing_n64() {
        let src = r#"{ "value": 3.5 }"#;
        let should_be = Dummy { value: n64(3.5) };

        let got: Dummy = serde_json::from_str(src).unwrap();
        assert_eq!(got, should_be);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serialize_struct_containing_n64() {
        let src = Dummy { value: n64(3.5) };
        let should_be = r#"{"value":3.5}"#;

        let got = serde_json::to_string(&src).unwrap();
        assert_eq!(got, should_be);
    }

    #[cfg(feature = "approx")]
    #[test]
    fn approx_assert_eq() {
        use approx::{assert_abs_diff_eq, assert_relative_eq, assert_ulps_eq};

        let lhs = r64(0.1000000000000001);
        let rhs = r64(0.1);

        assert_abs_diff_eq!(lhs, rhs);
        assert_relative_eq!(lhs, rhs);
        assert_ulps_eq!(lhs, rhs);
    }

    #[test]
    fn const_functions() {
        const A: N32 = N32::unchecked_new(1.0);
        const B: N64 = N64::unchecked_new(2.0);
        const C: R32 = R32::unchecked_new(3.0);
        const D: R64 = R64::unchecked_new(4.0);

        const A_RAW: f32 = A.const_raw();
        const B_RAW: f64 = B.const_raw();
        const C_RAW: f32 = C.const_raw();
        const D_RAW: f64 = D.const_raw();

        assert_eq!(A_RAW, 1.0);
        assert_eq!(B_RAW, 2.0);
        assert_eq!(C_RAW, 3.0);
        assert_eq!(D_RAW, 4.0);
    }
}