diff options
Diffstat (limited to 'linux/mean_and_variance.c')
-rw-r--r-- | linux/mean_and_variance.c | 72 |
1 files changed, 41 insertions, 31 deletions
diff --git a/linux/mean_and_variance.c b/linux/mean_and_variance.c index bd08da5f..eb5f2ba0 100644 --- a/linux/mean_and_variance.c +++ b/linux/mean_and_variance.c @@ -43,38 +43,28 @@ #include <linux/mean_and_variance.h> #include <linux/module.h> -/** - * fast_divpow2() - fast approximation for n / (1 << d) - * @n: numerator - * @d: the power of 2 denominator. - * - * note: this rounds towards 0. - */ -s64 fast_divpow2(s64 n, u8 d) +u128_u u128_div(u128_u n, u64 d) { - return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; -} + u128_u r; + u64 rem; + u64 hi = u128_hi(n); + u64 lo = u128_lo(n); + u64 h = hi & ((u64) U32_MAX << 32); + u64 l = (hi & (u64) U32_MAX) << 32; -/** - * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1 - * and return it. - * @s1: the mean_and_variance to update. - * @v1: the new sample. - * - * see linked pdf equation 12. - */ -struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1) -{ - return mean_and_variance_update_inlined(s1, v1); + r = u128_shl(u64_to_u128(div64_u64_rem(h, d, &rem)), 64); + r = u128_add(r, u128_shl(u64_to_u128(div64_u64_rem(l + (rem << 32), d, &rem)), 32)); + r = u128_add(r, u64_to_u128(div64_u64_rem(lo + (rem << 32), d, &rem))); + return r; } -EXPORT_SYMBOL_GPL(mean_and_variance_update); +EXPORT_SYMBOL_GPL(u128_div); /** * mean_and_variance_get_mean() - get mean from @s */ s64 mean_and_variance_get_mean(struct mean_and_variance s) { - return div64_u64(s.sum, s.n); + return s.n ? div64_u64(s.sum, s.n) : 0; } EXPORT_SYMBOL_GPL(mean_and_variance_get_mean); @@ -85,10 +75,14 @@ EXPORT_SYMBOL_GPL(mean_and_variance_get_mean); */ u64 mean_and_variance_get_variance(struct mean_and_variance s1) { - u128 s2 = u128_div(s1.sum_squares, s1.n); - u64 s3 = abs(mean_and_variance_get_mean(s1)); + if (s1.n) { + u128_u s2 = u128_div(s1.sum_squares, s1.n); + u64 s3 = abs(mean_and_variance_get_mean(s1)); - return u128_to_u64(u128_sub(s2, u128_square(s3))); + return u128_lo(u128_sub(s2, u128_square(s3))); + } else { + return 0; + } } EXPORT_SYMBOL_GPL(mean_and_variance_get_variance); @@ -109,10 +103,26 @@ EXPORT_SYMBOL_GPL(mean_and_variance_get_stddev); * see linked pdf: function derived from equations 140-143 where alpha = 2^w. * values are stored bitshifted for performance and added precision. */ -struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_and_variance_weighted s1, - s64 x) +void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 x) { - return mean_and_variance_weighted_update_inlined(s1, x); + // previous weighted variance. + u8 w = s->weight; + u64 var_w0 = s->variance; + // new value weighted. + s64 x_w = x << w; + s64 diff_w = x_w - s->mean; + s64 diff = fast_divpow2(diff_w, w); + // new mean weighted. + s64 u_w1 = s->mean + diff; + + if (!s->init) { + s->mean = x_w; + s->variance = 0; + } else { + s->mean = u_w1; + s->variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w; + } + s->init = true; } EXPORT_SYMBOL_GPL(mean_and_variance_weighted_update); @@ -121,7 +131,7 @@ EXPORT_SYMBOL_GPL(mean_and_variance_weighted_update); */ s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s) { - return fast_divpow2(s.mean, s.w); + return fast_divpow2(s.mean, s.weight); } EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_mean); @@ -131,7 +141,7 @@ EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_mean); u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s) { // always positive don't need fast divpow2 - return s.variance >> s.w; + return s.variance >> s.weight; } EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_variance); |