diff --git a/src/lib.rs b/src/lib.rs index a22a1ce..80f810e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,7 @@ use alloc::{vec, vec::Vec}; #[cfg(not(feature = "std"))] use core as std; +use std::marker::PhantomData; mod range; @@ -52,18 +53,23 @@ fn div_rem(x: usize, d: usize) -> (usize, usize) { /// capacity can grow using the `grow` method). #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FixedBitSet { +pub struct FixedBitSet { data: Vec, /// length in bits length: usize, + kind: PhantomData, } -impl FixedBitSet { +impl FixedBitSet +where + K: From + Into, +{ /// Create a new empty **FixedBitSet**. - pub const fn new() -> Self { + pub fn new() -> Self { FixedBitSet { data: Vec::new(), length: 0, + kind: PhantomData, } } @@ -75,6 +81,7 @@ impl FixedBitSet { FixedBitSet { data: vec![0; blocks], length: bits, + kind: PhantomData, } } @@ -88,7 +95,7 @@ impl FixedBitSet { /// For example: /// ``` /// let data = vec![4]; - /// let bs = fixedbitset::FixedBitSet::with_capacity_and_blocks(4, data); + /// let bs = fixedbitset::FixedBitSet::::with_capacity_and_blocks(4, data); /// assert_eq!(format!("{:b}", bs), "0010"); /// ``` pub fn with_capacity_and_blocks>(bits: usize, blocks: I) -> Self { @@ -106,7 +113,11 @@ impl FixedBitSet { *data.get_unchecked_mut(block) &= !mask; } } - FixedBitSet { data, length: bits } + FixedBitSet { + data, + length: bits.into(), + kind: PhantomData, + } } /// Grow capacity to **bits**, all new bits initialized to zero @@ -121,14 +132,14 @@ impl FixedBitSet { /// Return the length of the [`FixedBitSet`] in bits. #[inline] - pub fn len(&self) -> usize { - self.length + pub fn len(&self) -> K { + self.length.into() } /// Return if the [`FixedBitSet`] is empty. #[inline] pub fn is_empty(&self) -> bool { - self.len() == 0 + self.length == 0 } /// Return **true** if the bit is enabled in the **FixedBitSet**, @@ -138,7 +149,8 @@ impl FixedBitSet { /// /// Note: Also available with index syntax: `bitset[bit]`. #[inline] - pub fn contains(&self, bit: usize) -> bool { + pub fn contains(&self, bit: K) -> bool { + let bit = bit.into(); let (block, i) = div_rem(bit, BITS); match self.data.get(block) { None => false, @@ -158,7 +170,8 @@ impl FixedBitSet { /// /// **Panics** if **bit** is out of bounds. #[inline] - pub fn insert(&mut self, bit: usize) { + pub fn insert(&mut self, bit: K) { + let bit = bit.into(); assert!( bit < self.length, "insert at index {} exceeds fixbitset size {}", @@ -175,7 +188,8 @@ impl FixedBitSet { /// /// **Panics** if **bit** is out of bounds. #[inline] - pub fn put(&mut self, bit: usize) -> bool { + pub fn put(&mut self, bit: K) -> bool { + let bit = bit.into(); assert!( bit < self.length, "put at index {} exceeds fixbitset size {}", @@ -194,7 +208,8 @@ impl FixedBitSet { /// /// ***Panics*** if **bit** is out of bounds #[inline] - pub fn toggle(&mut self, bit: usize) { + pub fn toggle(&mut self, bit: K) { + let bit = bit.into(); assert!( bit < self.length, "toggle at index {} exceeds fixbitset size {}", @@ -208,7 +223,8 @@ impl FixedBitSet { } /// **Panics** if **bit** is out of bounds. #[inline] - pub fn set(&mut self, bit: usize, enabled: bool) { + pub fn set(&mut self, bit: K, enabled: bool) { + let bit = bit.into(); assert!( bit < self.length, "set at index {} exceeds fixbitset size {}", @@ -230,7 +246,9 @@ impl FixedBitSet { /// /// **Panics** if **to** is out of bounds. #[inline] - pub fn copy_bit(&mut self, from: usize, to: usize) { + pub fn copy_bit(&mut self, from: K, to: K) { + let from = from.into(); + let to = to.into(); assert!( to < self.length, "copy at index {} exceeds fixbitset size {}", @@ -238,7 +256,7 @@ impl FixedBitSet { self.length ); let (to_block, t) = div_rem(to, BITS); - let enabled = self.contains(from); + let enabled = self.contains(from.into()); unsafe { let to_elt = self.data.get_unchecked_mut(to_block); if enabled { @@ -323,23 +341,25 @@ impl FixedBitSet { /// /// Iterator element is the index of the `1` bit, type `usize`. #[inline] - pub fn ones(&self) -> Ones { + pub fn ones(&self) -> Ones { match self.as_slice().split_first() { Some((&block, rem)) => Ones { bitset: block, block_idx: 0, remaining_blocks: rem, + kind: PhantomData, }, None => Ones { bitset: 0, block_idx: 0, remaining_blocks: &[], + kind: PhantomData, }, } } /// Returns a lazy iterator over the intersection of two `FixedBitSet`s - pub fn intersection<'a>(&'a self, other: &'a FixedBitSet) -> Intersection<'a> { + pub fn intersection<'a>(&'a self, other: &'a FixedBitSet) -> Intersection<'a, K> { Intersection { iter: self.ones(), other, @@ -347,7 +367,7 @@ impl FixedBitSet { } /// Returns a lazy iterator over the union of two `FixedBitSet`s. - pub fn union<'a>(&'a self, other: &'a FixedBitSet) -> Union<'a> { + pub fn union<'a>(&'a self, other: &'a FixedBitSet) -> Union<'a, K> { Union { iter: self.ones().chain(other.difference(self)), } @@ -355,7 +375,7 @@ impl FixedBitSet { /// Returns a lazy iterator over the difference of two `FixedBitSet`s. The difference of `a` /// and `b` is the elements of `a` which are not in `b`. - pub fn difference<'a>(&'a self, other: &'a FixedBitSet) -> Difference<'a> { + pub fn difference<'a>(&'a self, other: &'a FixedBitSet) -> Difference<'a, K> { Difference { iter: self.ones(), other, @@ -364,7 +384,10 @@ impl FixedBitSet { /// Returns a lazy iterator over the symmetric difference of two `FixedBitSet`s. /// The symmetric difference of `a` and `b` is the elements of one, but not both, sets. - pub fn symmetric_difference<'a>(&'a self, other: &'a FixedBitSet) -> SymmetricDifference<'a> { + pub fn symmetric_difference<'a>( + &'a self, + other: &'a FixedBitSet, + ) -> SymmetricDifference<'a, K> { SymmetricDifference { iter: self.difference(other).chain(other.difference(self)), } @@ -373,9 +396,9 @@ impl FixedBitSet { /// In-place union of two `FixedBitSet`s. /// /// On calling this method, `self`'s capacity may be increased to match `other`'s. - pub fn union_with(&mut self, other: &FixedBitSet) { - if other.len() >= self.len() { - self.grow(other.len()); + pub fn union_with(&mut self, other: &FixedBitSet) { + if other.length >= self.length { + self.grow(other.length); } for (x, y) in self.data.iter_mut().zip(other.data.iter()) { *x |= *y; @@ -385,7 +408,7 @@ impl FixedBitSet { /// In-place intersection of two `FixedBitSet`s. /// /// On calling this method, `self`'s capacity will remain the same as before. - pub fn intersect_with(&mut self, other: &FixedBitSet) { + pub fn intersect_with(&mut self, other: &FixedBitSet) { for (x, y) in self.data.iter_mut().zip(other.data.iter()) { *x &= *y; } @@ -398,7 +421,7 @@ impl FixedBitSet { /// In-place difference of two `FixedBitSet`s. /// /// On calling this method, `self`'s capacity will remain the same as before. - pub fn difference_with(&mut self, other: &FixedBitSet) { + pub fn difference_with(&mut self, other: &FixedBitSet) { for (x, y) in self.data.iter_mut().zip(other.data.iter()) { *x &= !*y; } @@ -414,9 +437,9 @@ impl FixedBitSet { /// In-place symmetric difference of two `FixedBitSet`s. /// /// On calling this method, `self`'s capacity may be increased to match `other`'s. - pub fn symmetric_difference_with(&mut self, other: &FixedBitSet) { - if other.len() >= self.len() { - self.grow(other.len()); + pub fn symmetric_difference_with(&mut self, other: &FixedBitSet) { + if other.length >= self.length { + self.grow(other.length); } for (x, y) in self.data.iter_mut().zip(other.data.iter()) { *x ^= *y; @@ -425,7 +448,7 @@ impl FixedBitSet { /// Returns `true` if `self` has no elements in common with `other`. This /// is equivalent to checking for an empty intersection. - pub fn is_disjoint(&self, other: &FixedBitSet) -> bool { + pub fn is_disjoint(&self, other: &FixedBitSet) -> bool { self.data .iter() .zip(other.data.iter()) @@ -434,7 +457,7 @@ impl FixedBitSet { /// Returns `true` if the set is a subset of another, i.e. `other` contains /// at least all the values in `self`. - pub fn is_subset(&self, other: &FixedBitSet) -> bool { + pub fn is_subset(&self, other: &FixedBitSet) -> bool { self.data .iter() .zip(other.data.iter()) @@ -444,19 +467,21 @@ impl FixedBitSet { /// Returns `true` if the set is a superset of another, i.e. `self` contains /// at least all the values in `other`. - pub fn is_superset(&self, other: &FixedBitSet) -> bool { + pub fn is_superset(&self, other: &FixedBitSet) -> bool { other.is_subset(self) } } - -impl Binary for FixedBitSet { +impl Binary for FixedBitSet +where + K: From + Into, +{ fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { if f.alternate() { f.write_str("0b")?; } for i in 0..self.length { - if self[i] { + if self[i.into()] { f.write_char('1')?; } else { f.write_char('0')?; @@ -467,7 +492,10 @@ impl Binary for FixedBitSet { } } -impl Display for FixedBitSet { +impl Display for FixedBitSet +where + K: From + Into, +{ fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { Binary::fmt(&self, f) } @@ -476,19 +504,23 @@ impl Display for FixedBitSet { /// An iterator producing elements in the difference of two sets. /// /// This struct is created by the [`FixedBitSet::difference`] method. -pub struct Difference<'a> { - iter: Ones<'a>, - other: &'a FixedBitSet, +pub struct Difference<'a, K> { + iter: Ones<'a, K>, + other: &'a FixedBitSet, } -impl<'a> Iterator for Difference<'a> { - type Item = usize; +impl<'a, K> Iterator for Difference<'a, K> +where + K: From + Into, +{ + type Item = K; #[inline] fn next(&mut self) -> Option { while let Some(nxt) = self.iter.next() { - if !self.other.contains(nxt) { - return Some(nxt); + let nxt: usize = nxt.into(); // workaround missing Copy bound on K + if !self.other.contains(nxt.into()) { + return Some(nxt.into()); } } None @@ -498,12 +530,15 @@ impl<'a> Iterator for Difference<'a> { /// An iterator producing elements in the symmetric difference of two sets. /// /// This struct is created by the [`FixedBitSet::symmetric_difference`] method. -pub struct SymmetricDifference<'a> { - iter: Chain, Difference<'a>>, +pub struct SymmetricDifference<'a, K> { + iter: Chain, Difference<'a, K>>, } -impl<'a> Iterator for SymmetricDifference<'a> { - type Item = usize; +impl<'a, K> Iterator for SymmetricDifference<'a, K> +where + K: From + Into, +{ + type Item = K; #[inline] fn next(&mut self) -> Option { @@ -514,19 +549,23 @@ impl<'a> Iterator for SymmetricDifference<'a> { /// An iterator producing elements in the intersection of two sets. /// /// This struct is created by the [`FixedBitSet::intersection`] method. -pub struct Intersection<'a> { - iter: Ones<'a>, - other: &'a FixedBitSet, +pub struct Intersection<'a, K> { + iter: Ones<'a, K>, + other: &'a FixedBitSet, } -impl<'a> Iterator for Intersection<'a> { - type Item = usize; // the bit position of the '1' +impl<'a, K> Iterator for Intersection<'a, K> +where + K: From + Into, +{ + type Item = K; // the bit position of the '1' #[inline] fn next(&mut self) -> Option { while let Some(nxt) = self.iter.next() { - if self.other.contains(nxt) { - return Some(nxt); + let nxt: usize = nxt.into(); // workaround missing Copy bound on K + if self.other.contains(nxt.into()) { + return Some(nxt.into()); } } None @@ -536,12 +575,15 @@ impl<'a> Iterator for Intersection<'a> { /// An iterator producing elements in the union of two sets. /// /// This struct is created by the [`FixedBitSet::union`] method. -pub struct Union<'a> { - iter: Chain, Difference<'a>>, +pub struct Union<'a, K> { + iter: Chain, Difference<'a, K>>, } -impl<'a> Iterator for Union<'a> { - type Item = usize; +impl<'a, K> Iterator for Union<'a, K> +where + K: From + Into, +{ + type Item = K; #[inline] fn next(&mut self) -> Option { @@ -611,14 +653,18 @@ impl Iterator for Masks { /// An iterator producing the indices of the set bit in a set. /// /// This struct is created by the [`FixedBitSet::ones`] method. -pub struct Ones<'a> { +pub struct Ones<'a, K> { bitset: Block, block_idx: usize, remaining_blocks: &'a [Block], + kind: PhantomData, } -impl<'a> Iterator for Ones<'a> { - type Item = usize; // the bit position of the '1' +impl<'a, K> Iterator for Ones<'a, K> +where + K: From + Into, +{ + type Item = K; // the bit position of the '1' #[inline] fn next(&mut self) -> Option { @@ -633,16 +679,17 @@ impl<'a> Iterator for Ones<'a> { let t = self.bitset & (0 as Block).wrapping_sub(self.bitset); let r = self.bitset.trailing_zeros() as usize; self.bitset ^= t; - Some(self.block_idx * BITS + r) + Some((self.block_idx * BITS + r).into()) } } -impl Clone for FixedBitSet { +impl Clone for FixedBitSet { #[inline] fn clone(&self) -> Self { FixedBitSet { data: self.data.clone(), length: self.length, + kind: self.kind, } } } @@ -652,11 +699,14 @@ impl Clone for FixedBitSet { /// /// Note: bits outside the capacity are always disabled, and thus /// indexing a FixedBitSet will not panic. -impl Index for FixedBitSet { +impl Index for FixedBitSet +where + K: From + Into, +{ type Output = bool; #[inline] - fn index(&self, bit: usize) -> &bool { + fn index(&self, bit: K) -> &bool { if self.contains(bit) { &true } else { @@ -666,33 +716,43 @@ impl Index for FixedBitSet { } /// Sets the bit at index **i** to **true** for each item **i** in the input **src**. -impl Extend for FixedBitSet { - fn extend>(&mut self, src: I) { +impl Extend for FixedBitSet +where + K: From + Into, +{ + fn extend>(&mut self, src: I) { let iter = src.into_iter(); for i in iter { - if i >= self.len() { - self.grow(i + 1); + let i = i.into(); + if i >= self.length { + self.grow((i + 1).into()); } - self.put(i); + self.put(i.into()); } } } /// Return a FixedBitSet containing bits set to **true** for every bit index in /// the iterator, other bits are set to **false**. -impl FromIterator for FixedBitSet { - fn from_iter>(src: I) -> Self { - let mut fbs = FixedBitSet::with_capacity(0); +impl FromIterator for FixedBitSet +where + K: From + Into, +{ + fn from_iter>(src: I) -> Self { + let mut fbs = FixedBitSet::::new(); fbs.extend(src); fbs } } -impl<'a> BitAnd for &'a FixedBitSet { - type Output = FixedBitSet; - fn bitand(self, other: &FixedBitSet) -> FixedBitSet { +impl<'a, K> BitAnd for &'a FixedBitSet +where + K: From + Into, +{ + type Output = FixedBitSet; + fn bitand(self, other: &FixedBitSet) -> FixedBitSet { let (short, long) = { - if self.len() <= other.len() { + if self.length <= other.length { (&self.data, &other.data) } else { (&other.data, &self.data) @@ -702,28 +762,41 @@ impl<'a> BitAnd for &'a FixedBitSet { for (data, block) in data.iter_mut().zip(long.iter()) { *data &= *block; } - let len = std::cmp::min(self.len(), other.len()); - FixedBitSet { data, length: len } + let len = std::cmp::min(self.length, other.length); + FixedBitSet { + data, + length: len, + kind: PhantomData, + } } } -impl<'a> BitAndAssign for FixedBitSet { +impl<'a, K> BitAndAssign for FixedBitSet +where + K: From + Into, +{ fn bitand_assign(&mut self, other: Self) { self.intersect_with(&other); } } -impl<'a> BitAndAssign<&Self> for FixedBitSet { +impl<'a, K> BitAndAssign<&Self> for FixedBitSet +where + K: From + Into, +{ fn bitand_assign(&mut self, other: &Self) { self.intersect_with(other); } } -impl<'a> BitOr for &'a FixedBitSet { - type Output = FixedBitSet; - fn bitor(self, other: &FixedBitSet) -> FixedBitSet { +impl<'a, K> BitOr for &'a FixedBitSet +where + K: From + Into, +{ + type Output = FixedBitSet; + fn bitor(self, other: &FixedBitSet) -> FixedBitSet { let (short, long) = { - if self.len() <= other.len() { + if self.length <= other.length { (&self.data, &other.data) } else { (&other.data, &self.data) @@ -733,28 +806,41 @@ impl<'a> BitOr for &'a FixedBitSet { for (data, block) in data.iter_mut().zip(short.iter()) { *data |= *block; } - let len = std::cmp::max(self.len(), other.len()); - FixedBitSet { data, length: len } + let len = std::cmp::max(self.length, other.length); + FixedBitSet { + data, + length: len, + kind: PhantomData, + } } } -impl<'a> BitOrAssign for FixedBitSet { +impl<'a, K> BitOrAssign for FixedBitSet +where + K: From + Into, +{ fn bitor_assign(&mut self, other: Self) { self.union_with(&other); } } -impl<'a> BitOrAssign<&Self> for FixedBitSet { +impl<'a, K> BitOrAssign<&Self> for FixedBitSet +where + K: From + Into, +{ fn bitor_assign(&mut self, other: &Self) { self.union_with(other); } } -impl<'a> BitXor for &'a FixedBitSet { - type Output = FixedBitSet; - fn bitxor(self, other: &FixedBitSet) -> FixedBitSet { +impl<'a, K> BitXor for &'a FixedBitSet +where + K: From + Into, +{ + type Output = FixedBitSet; + fn bitxor(self, other: &FixedBitSet) -> FixedBitSet { let (short, long) = { - if self.len() <= other.len() { + if self.length <= other.length { (&self.data, &other.data) } else { (&other.data, &self.data) @@ -764,18 +850,28 @@ impl<'a> BitXor for &'a FixedBitSet { for (data, block) in data.iter_mut().zip(short.iter()) { *data ^= *block; } - let len = std::cmp::max(self.len(), other.len()); - FixedBitSet { data, length: len } + let len = std::cmp::max(self.length, other.length); + FixedBitSet { + data, + length: len, + kind: PhantomData, + } } } -impl<'a> BitXorAssign for FixedBitSet { +impl<'a, K> BitXorAssign for FixedBitSet +where + K: From + Into, +{ fn bitxor_assign(&mut self, other: Self) { self.symmetric_difference_with(&other); } } -impl<'a> BitXorAssign<&Self> for FixedBitSet { +impl<'a, K> BitXorAssign<&Self> for FixedBitSet +where + K: From + Into, +{ fn bitxor_assign(&mut self, other: &Self) { self.symmetric_difference_with(other); } @@ -961,21 +1057,21 @@ fn iter_ones_range() { #[should_panic] #[test] fn count_ones_oob() { - let fb = FixedBitSet::with_capacity(100); + let fb = FixedBitSet::::with_capacity(100); fb.count_ones(90..101); } #[should_panic] #[test] fn count_ones_negative_range() { - let fb = FixedBitSet::with_capacity(100); + let fb = FixedBitSet::::with_capacity(100); fb.count_ones(90..80); } #[test] fn count_ones_panic() { for i in 1..128 { - let fb = FixedBitSet::with_capacity(i); + let fb = FixedBitSet::::with_capacity(i); for j in 0..fb.len() + 1 { for k in j..fb.len() + 1 { assert_eq!(fb.count_ones(j..k), 0); @@ -986,7 +1082,7 @@ fn count_ones_panic() { #[test] fn default() { - let fb = FixedBitSet::default(); + let fb = FixedBitSet::::default(); assert_eq!(fb.len(), 0); } @@ -1448,7 +1544,7 @@ fn bitxor_assign_longer() { #[test] fn op_assign_ref() { - let mut a = FixedBitSet::with_capacity(8); + let mut a = FixedBitSet::::with_capacity(8); let b = FixedBitSet::with_capacity(8); //check that all assign type operators work on references @@ -1571,6 +1667,39 @@ fn from_iterator_ones() { ); } +#[test] +fn other_types() { + use std::convert::TryInto; + struct CustomIndex(pub u16); + + impl From for CustomIndex { + fn from(value: usize) -> Self { + CustomIndex(value.try_into().expect("value too large to fit u16")) + } + } + impl Into for CustomIndex { + fn into(self) -> usize { + self.0.into() + } + } + + let mut fb = FixedBitSet::::with_capacity(40); + fb.insert_range(..10); + fb.insert_range(34..38); + + fb.toggle_range(5..12); + fb.toggle_range(30..); + + for i in 0..40 { + assert_eq!( + fb.contains(i.into()), + i < 5 || 10 <= i && i < 12 || 30 <= i && i < 34 || 38 <= i + ); + } + assert!(!fb.contains(CustomIndex(40))); + assert!(!fb.contains(CustomIndex(64))); +} + #[cfg(feature = "std")] #[test] fn binary_trait() {