Skip to content

Commit 26906d5

Browse files
committed
Merge branch 'features/spline' into dev
2 parents 6804677 + 1e91054 commit 26906d5

File tree

1 file changed

+145
-1
lines changed

1 file changed

+145
-1
lines changed

src/numerical/spline.rs

Lines changed: 145 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,15 @@
194194
//! }
195195
//! ```
196196
//!
197+
//! # B-Spline (incomplete)
198+
//!
199+
//! - `UnitCubicBasis`: Single cubic B-Spline basis function
200+
//! - `CubicBSplineBases`: Uniform Cubic B-Spline basis functions
201+
//!
197202
//! # References
198203
//!
199-
//! * Gary D. Knott, *Interpolating Splines*, Birkhäuser Boston, MA, (2000).
204+
//! - Gary D. Knott, *Interpolating Splines*, Birkhäuser Boston, MA, (2000).
205+
/// - [Wikipedia - Irwin-Hall distribution](https://en.wikipedia.org/wiki/Irwin%E2%80%93Hall_distribution#Special_cases)
200206
201207
use self::SplineError::{NotEnoughNodes, NotEqualNodes, NotEqualSlopes, RedundantNodeX};
202208
#[allow(unused_imports)]
@@ -843,3 +849,141 @@ fn quadratic_slopes(x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
843849

844850
Ok(m)
845851
}
852+
853+
// =============================================================================
854+
// B-Spline
855+
// =============================================================================
856+
/// Unit Cubic Basis Function
857+
///
858+
/// # Description
859+
/// Unit cubic basis function from Irwin-Hall distribution (n=4).
860+
/// For general interval, we substitute t = 4 * (x - a) / (b - a).
861+
///
862+
/// # Reference
863+
/// [Wikipedia](https://en.wikipedia.org/wiki/Irwin%E2%80%93Hall_distribution#Special_cases)
864+
#[derive(Debug, Copy, Clone)]
865+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
866+
pub struct UnitCubicBasis {
867+
pub x_min: f64,
868+
pub x_max: f64,
869+
pub scale: f64,
870+
}
871+
872+
impl UnitCubicBasis {
873+
pub fn new(x_min: f64, x_max: f64, scale: f64) -> Self {
874+
Self { x_min, x_max, scale }
875+
}
876+
877+
pub fn eval(&self, x: f64) -> f64 {
878+
let t = 4f64 * (x - self.x_min) / (self.x_max - self.x_min);
879+
880+
let result = if (0f64..1f64).contains(&t) {
881+
t.powi(3) / 6f64
882+
} else if (1f64..2f64).contains(&t) {
883+
(-3f64 * t.powi(3) + 12f64 * t.powi(2) - 12f64 * t + 4f64) / 6f64
884+
} else if (2f64..3f64).contains(&t) {
885+
(3f64 * t.powi(3) - 24f64 * t.powi(2) + 60f64 * t - 44f64) / 6f64
886+
} else if (3f64..4f64).contains(&t) {
887+
(4f64 - t).powi(3) / 6f64
888+
} else {
889+
0f64
890+
};
891+
892+
self.scale * result
893+
}
894+
895+
pub fn eval_vec(&self, x: &[f64]) -> Vec<f64> {
896+
x.iter().map(|x| self.eval(*x)).collect()
897+
}
898+
}
899+
900+
/// Uniform Cubic B-Spline basis functions
901+
///
902+
/// # Example
903+
///
904+
/// ```rust
905+
/// use peroxide::fuga::*;
906+
/// use core::ops::Range;
907+
///
908+
/// # #[allow(unused_variables)]
909+
/// fn main() -> anyhow::Result<()> {
910+
/// let cubic_b_spline = CubicBSplineBases::from_interval((0f64, 1f64), 5);
911+
/// let x = linspace(0f64, 1f64, 1000);
912+
/// let y = cubic_b_spline.eval_vec(&x);
913+
///
914+
/// # #[cfg(feature = "plot")] {
915+
/// let mut plt = Plot2D::new();
916+
/// plt.set_domain(x.clone());
917+
///
918+
/// for basis in &cubic_b_spline.bases {
919+
/// plt.insert_image(basis.eval_vec(&x));
920+
/// }
921+
///
922+
/// plt
923+
/// .insert_image(y)
924+
/// .set_xlabel(r"$x$")
925+
/// .set_ylabel(r"$y$")
926+
/// .set_style(PlotStyle::Nature)
927+
/// .tight_layout()
928+
/// .set_dpi(600)
929+
/// .set_path("example_data/cubic_b_spline.png")
930+
/// .savefig()?;
931+
/// # }
932+
/// Ok(())
933+
/// }
934+
/// ```
935+
#[derive(Debug, Clone)]
936+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
937+
pub struct CubicBSplineBases {
938+
pub ranges: Vec<Range<f64>>,
939+
pub bases: Vec<UnitCubicBasis>,
940+
}
941+
942+
impl CubicBSplineBases {
943+
/// Create new Cubic B-Spline basis functions
944+
pub fn new(ranges: Vec<Range<f64>>, bases: Vec<UnitCubicBasis>) -> Self {
945+
Self { ranges, bases }
946+
}
947+
948+
/// Create new Cubic B-Spline basis functions for `[a, b]`
949+
pub fn from_interval((a, b): (f64, f64), num_bases: usize) -> Self {
950+
let nodes = linspace(a, b, num_bases + 4);
951+
let (ranges, bases) = nodes
952+
.iter()
953+
.zip(nodes.iter().skip(4))
954+
.map(|(a, b)| (Range { start: *a, end: *b }, UnitCubicBasis::new(*a, *b, 1f64)))
955+
.unzip();
956+
957+
Self::new(ranges, bases)
958+
}
959+
960+
/// Rescale all basis functions
961+
///
962+
/// # Arguments
963+
/// - `scale_vec` - scale vector
964+
pub fn rescale(&mut self, scale_vec: &[f64]) -> Result<()> {
965+
if scale_vec.len() != self.bases.len() {
966+
bail!("The number of scales should be equal to the number of basis functions");
967+
}
968+
969+
for (basis, scale) in self.bases.iter_mut().zip(scale_vec) {
970+
basis.scale = *scale;
971+
}
972+
973+
Ok(())
974+
}
975+
976+
pub fn eval(&self, x: f64) -> f64 {
977+
self.ranges.iter()
978+
.enumerate()
979+
.filter(|(_, range)| range.contains(&x))
980+
.fold(0f64, |acc, (i, _)| {
981+
let basis = &self.bases[i];
982+
acc + basis.eval(x)
983+
})
984+
}
985+
986+
pub fn eval_vec(&self, x: &[f64]) -> Vec<f64> {
987+
x.iter().map(|x| self.eval(*x)).collect()
988+
}
989+
}

0 commit comments

Comments
 (0)