diff options
Diffstat (limited to 'linux')
-rw-r--r-- | linux/int_sqrt.c | 71 | ||||
-rw-r--r-- | linux/mean_and_variance.c | 178 | ||||
-rw-r--r-- | linux/six.c | 11 |
3 files changed, 257 insertions, 3 deletions
diff --git a/linux/int_sqrt.c b/linux/int_sqrt.c new file mode 100644 index 00000000..a8170bb9 --- /dev/null +++ b/linux/int_sqrt.c @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (C) 2013 Davidlohr Bueso <davidlohr.bueso@hp.com> + * + * Based on the shift-and-subtract algorithm for computing integer + * square root from Guy L. Steele. + */ + +#include <linux/export.h> +#include <linux/bitops.h> +#include <linux/limits.h> +#include <linux/math.h> + +/** + * int_sqrt - computes the integer square root + * @x: integer of which to calculate the sqrt + * + * Computes: floor(sqrt(x)) + */ +unsigned long int_sqrt(unsigned long x) +{ + unsigned long b, m, y = 0; + + if (x <= 1) + return x; + + m = 1UL << (__fls(x) & ~1UL); + while (m != 0) { + b = y + m; + y >>= 1; + + if (x >= b) { + x -= b; + y += m; + } + m >>= 2; + } + + return y; +} +EXPORT_SYMBOL(int_sqrt); + +#if BITS_PER_LONG < 64 +/** + * int_sqrt64 - strongly typed int_sqrt function when minimum 64 bit input + * is expected. + * @x: 64bit integer of which to calculate the sqrt + */ +u32 int_sqrt64(u64 x) +{ + u64 b, m, y = 0; + + if (x <= ULONG_MAX) + return int_sqrt((unsigned long) x); + + m = 1ULL << ((fls64(x) - 1) & ~1ULL); + while (m != 0) { + b = y + m; + y >>= 1; + + if (x >= b) { + x -= b; + y += m; + } + m >>= 2; + } + + return y; +} +EXPORT_SYMBOL(int_sqrt64); +#endif diff --git a/linux/mean_and_variance.c b/linux/mean_and_variance.c new file mode 100644 index 00000000..643e3113 --- /dev/null +++ b/linux/mean_and_variance.c @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Functions for incremental mean and variance. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 as published by + * the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * Copyright © 2022 Daniel B. Hill + * + * Author: Daniel B. Hill <daniel@gluo.nz> + * + * Description: + * + * This is includes some incremental algorithms for mean and variance calculation + * + * Derived from the paper: https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf + * + * Create a struct and if it's the weighted variant set the w field (weight = 2^k). + * + * Use mean_and_variance[_weighted]_update() on the struct to update it's state. + * + * Use the mean_and_variance[_weighted]_get_* functions to calculate the mean and variance, some computation + * is deferred to these functions for performance reasons. + * + * see lib/math/mean_and_variance_test.c for examples of usage. + * + * DO NOT access the mean and variance fields of the weighted variants directly. + * DO NOT change the weight after calling update. + */ + +#include <linux/bug.h> +#include <linux/compiler.h> +#include <linux/export.h> +#include <linux/limits.h> +#include <linux/math.h> +#include <linux/math64.h> +#include <linux/mean_and_variance.h> +#include <linux/module.h> +#include <linux/printbuf.h> + + +/** + * fast_divpow2() - fast approximation for n / (1 << d) + * @n: numerator + * @d: the power of 2 denominator. + * + * note: this rounds towards 0. + */ +inline s64 fast_divpow2(s64 n, u8 d) +{ + return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; +} + +/** + * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1 + * and return it. + * @s1: the mean_and_variance to update. + * @v1: the new sample. + * + * see linked pdf equation 12. + */ +struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1) +{ + struct mean_and_variance s2; + u64 v2 = abs(v1); + + s2.n = s1.n + 1; + s2.sum = s1.sum + v1; + s2.sum_squares = u128_add(s1.sum_squares, u128_square(v2)); + return s2; +} +EXPORT_SYMBOL_GPL(mean_and_variance_update); + +/** + * mean_and_variance_get_mean() - get mean from @s + */ +s64 mean_and_variance_get_mean(struct mean_and_variance s) +{ + return div64_u64(s.sum, s.n); +} +EXPORT_SYMBOL_GPL(mean_and_variance_get_mean); + +/** + * mean_and_variance_get_variance() - get variance from @s1 + * + * see linked pdf equation 12. + */ +u64 mean_and_variance_get_variance(struct mean_and_variance s1) +{ + u128 s2 = u128_div(s1.sum_squares, s1.n); + u64 s3 = abs(mean_and_variance_get_mean(s1)); + + return u128_to_u64(u128_sub(s2, u128_square(s3))); +} +EXPORT_SYMBOL_GPL(mean_and_variance_get_variance); + +/** + * mean_and_variance_get_stddev() - get standard deviation from @s + */ +u32 mean_and_variance_get_stddev(struct mean_and_variance s) +{ + return int_sqrt64(mean_and_variance_get_variance(s)); +} +EXPORT_SYMBOL_GPL(mean_and_variance_get_stddev); + +/** + * mean_and_variance_weighted_update() - exponentially weighted variant of mean_and_variance_update() + * @s1: .. + * @s2: .. + * + * see linked pdf: function derived from equations 140-143 where alpha = 2^w. + * values are stored bitshifted for performance and added precision. + */ +struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_and_variance_weighted s1, + s64 x) +{ + struct mean_and_variance_weighted s2; + // previous weighted variance. + u64 var_w0 = s1.variance; + u8 w = s2.w = s1.w; + // new value weighted. + s64 x_w = x << w; + s64 diff_w = x_w - s1.mean; + s64 diff = fast_divpow2(diff_w, w); + // new mean weighted. + s64 u_w1 = s1.mean + diff; + + BUG_ON(w % 2 != 0); + + if (!s1.init) { + s2.mean = x_w; + s2.variance = 0; + } else { + s2.mean = u_w1; + s2.variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w; + } + s2.init = true; + + return s2; +} +EXPORT_SYMBOL_GPL(mean_and_variance_weighted_update); + +/** + * mean_and_variance_weighted_get_mean() - get mean from @s + */ +s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s) +{ + return fast_divpow2(s.mean, s.w); +} +EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_mean); + +/** + * mean_and_variance_weighted_get_variance() -- get variance from @s + */ +u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s) +{ + // always positive don't need fast divpow2 + return s.variance >> s.w; +} +EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_variance); + +/** + * mean_and_variance_weighted_get_stddev() - get standard deviation from @s + */ +u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s) +{ + return int_sqrt64(mean_and_variance_weighted_get_variance(s)); +} +EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_stddev); + +MODULE_AUTHOR("Daniel B. Hill"); +MODULE_LICENSE("GPL"); diff --git a/linux/six.c b/linux/six.c index b11660af..39f7ea79 100644 --- a/linux/six.c +++ b/linux/six.c @@ -148,6 +148,14 @@ static int __do_six_trylock_type(struct six_lock *lock, atomic64_add(__SIX_VAL(write_locking, 1), &lock->state.counter); smp_mb__after_atomic(); + } else if (!(lock->state.waiters & (1 << SIX_LOCK_write))) { + atomic64_add(__SIX_VAL(waiters, 1 << SIX_LOCK_write), + &lock->state.counter); + /* + * pairs with barrier after unlock and before checking + * for readers in unlock path + */ + smp_mb__after_atomic(); } ret = !pcpu_read_count(lock); @@ -162,9 +170,6 @@ static int __do_six_trylock_type(struct six_lock *lock, if (ret || try) v -= __SIX_VAL(write_locking, 1); - if (!ret && !try && !(lock->state.waiters & (1 << SIX_LOCK_write))) - v += __SIX_VAL(waiters, 1 << SIX_LOCK_write); - if (try && !ret) { old.v = atomic64_add_return(v, &lock->state.counter); if (old.waiters & (1 << SIX_LOCK_read)) |