summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Hill <daniel@gluo.nz>2022-08-06 14:48:49 +1200
committerDaniel Hill <daniel@gluo.nz>2022-09-02 17:55:36 +1200
commitb30d8e5503da071512123b26e00d3c777530a540 (patch)
tree7274205b8bfd4a36143ecd842d3b32b4778f0db6
parentd07538c3c60ac36ddc776b563fb8b0810622cc25 (diff)
lib: add mean and variance module.
This module provides a fast 64bit implementation of basic statistics functions, including mean, variance and standard deviation in both weighted and unweighted variants, the unweighted variant has a 32bit limitation per sample to prevent overflow when squaring. Signed-off-by: Daniel Hill <daniel@gluo.nz>
-rw-r--r--include/linux/mean_and_variance.h37
-rw-r--r--lib/math/Kconfig11
-rw-r--r--lib/math/Makefile2
-rw-r--r--lib/math/mean_and_variance.c176
-rw-r--r--lib/math/mean_and_variance_test.c161
5 files changed, 387 insertions, 0 deletions
diff --git a/include/linux/mean_and_variance.h b/include/linux/mean_and_variance.h
new file mode 100644
index 000000000000..280a16dea789
--- /dev/null
+++ b/include/linux/mean_and_variance.h
@@ -0,0 +1,37 @@
+// SPDX-License-Identifier: GPL-2.0
+#ifndef STATS_H_
+#define STATS_H_
+
+#include <linux/types.h>
+
+#define SQRT_U64_MAX 4294967295ULL
+
+struct mean_and_variance {
+ s64 n;
+ s64 sum;
+ u64 sum_squares;
+};
+
+/* expontentially weighted variant */
+struct mean_and_variance_ewm {
+ bool init;
+ u8 w;
+ s64 mean;
+ u64 variance;
+};
+
+#ifdef CONFIG_MEAN_AND_VARIANCE_UNIT_TEST
+s64 fast_divpow2(s64 n, u8 d);
+#endif
+
+struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1);
+ s64 get_mean(struct mean_and_variance s);
+ u64 get_variance(struct mean_and_variance s1);
+ u32 get_stddev(struct mean_and_variance s);
+
+struct mean_and_variance_ewm mean_and_variance_ewm_update(struct mean_and_variance_ewm s1, s64 v1);
+ s64 get_ewm_mean(struct mean_and_variance_ewm s);
+ u64 get_ewm_variance(struct mean_and_variance_ewm s);
+ u32 get_ewm_stddev(struct mean_and_variance_ewm s);
+
+#endif // STATS_H_
diff --git a/lib/math/Kconfig b/lib/math/Kconfig
index 0634b428d0cb..33aa9afc1ce8 100644
--- a/lib/math/Kconfig
+++ b/lib/math/Kconfig
@@ -15,3 +15,14 @@ config PRIME_NUMBERS
config RATIONAL
tristate
+
+config MEAN_AND_VARIANCE
+ tristate "fast incremental integer mean and vairance module"
+ help
+ This option provides functions for calculating mean and standard
+ deviation incrementally, standard and expotentially weighted variants"
+
+config MEAN_AND_VARIANCE_UNIT_TEST
+ tristate "mean_and_variance unit tests" if !KUNIT_ALL_TESTS
+ depends on MEAN_AND_VARIANCE && KUNIT
+ default KUNIT_ALL_TESTS
diff --git a/lib/math/Makefile b/lib/math/Makefile
index bfac26ddfc22..2ef1487e01c2 100644
--- a/lib/math/Makefile
+++ b/lib/math/Makefile
@@ -4,6 +4,8 @@ obj-y += div64.o gcd.o lcm.o int_pow.o int_sqrt.o reciprocal_div.o
obj-$(CONFIG_CORDIC) += cordic.o
obj-$(CONFIG_PRIME_NUMBERS) += prime_numbers.o
obj-$(CONFIG_RATIONAL) += rational.o
+obj-$(CONFIG_MEAN_AND_VARIANCE) += mean_and_variance.o
obj-$(CONFIG_TEST_DIV64) += test_div64.o
obj-$(CONFIG_RATIONAL_KUNIT_TEST) += rational-test.o
+obj-$(CONFIG_MEAN_AND_VARIANCE_UNIT_TEST) += mean_and_variance_test.o
diff --git a/lib/math/mean_and_variance.c b/lib/math/mean_and_variance.c
new file mode 100644
index 000000000000..08c1b3c74198
--- /dev/null
+++ b/lib/math/mean_and_variance.c
@@ -0,0 +1,176 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Functions for incremental mean and variance.
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 as published by
+ * the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * Copyright © 2022 Daniel B. Hill
+ *
+ * Author: Daniel B. Hill <daniel@gluo.nz>
+ *
+ * Description:
+ *
+ * This is includes some incremental algorthims for mean and variance calculation
+ *
+ * Derived from the paper: https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf
+ *
+ * Create a struct and if it's the ewm variant set the w field (weight = 2^k).
+ *
+ * Use mean_and_variance[_ewm]_update() on the struct to update it's state.
+ *
+ * Use the get_* functions to calculate the mean and variance, some computation
+ * is deffered to these functions for peformance reasons.
+ *
+ * see lib/math/mean_and_variance_test.c for examples of usage.
+ *
+ * DO NOT access the mean and variance fields of the ewm variants directly.
+ * DO NOT change the weight after calling update.
+ */
+
+#include <linux/bug.h>
+#include <linux/compiler.h>
+#include <linux/export.h>
+#include <linux/limits.h>
+#include <linux/math.h>
+#include <linux/math64.h>
+#include <linux/mean_and_variance.h>
+#include <linux/module.h>
+#include <linux/printbuf.h>
+
+
+inline s64 fast_divpow2(s64 n, u8 d)
+{
+ return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; // + (n < 0 ? 1 : 0);
+}
+
+/**
+ * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1
+ * and return it.
+ * @s1: the mean_and_variance to update.
+ * @v1: the new sample.
+ *
+ * see linked pdf equation 12.
+ **/
+inline struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1)
+{
+ struct mean_and_variance s2;
+ u64 v2 = abs(v1);
+
+ if (v2 > SQRT_U64_MAX) {
+ v2 = SQRT_U64_MAX;
+ WARN(true, "stats overflow! %lld^2 > U64_MAX", v1);
+ }
+
+ s2.n = s1.n + 1;
+ s2.sum = s1.sum + v1;
+ s2.sum_squares = s1.sum_squares + v2*v2;
+ return s2;
+}
+EXPORT_SYMBOL_GPL(mean_and_variance_update);
+
+/**
+ * get_mean() - get mean from @s
+ */
+inline s64 get_mean(struct mean_and_variance s)
+{
+ return s.sum / s.n;
+}
+EXPORT_SYMBOL_GPL(get_mean);
+
+/**
+ * get_variance() - get variance from @s1
+ *
+ * see linked pdf equation 12.
+ */
+inline u64 get_variance(struct mean_and_variance s1)
+{
+ u64 s2 = s1.sum_squares / s1.n;
+ u64 s3 = abs(get_mean(s1));
+
+ WARN(s3 > SQRT_U64_MAX, "stats overflow %lld ^2 > S64_MAX", s3);
+ return s2 - s3*s3;
+}
+EXPORT_SYMBOL_GPL(get_variance);
+
+/**
+ * get_stddev() - get standard deviation from @s
+ */
+inline u32 get_stddev(struct mean_and_variance s)
+{
+ return int_sqrt64(get_variance(s));
+}
+EXPORT_SYMBOL_GPL(get_stddev);
+
+/**
+ * mean_and_variance_evm_update() - expontentially weighted variant of mean_and_variance_update()
+ * @s1: ..
+ * @s2: ..
+ *
+ * see linked pdf: function derived from equations 140-143 where alpha = 2^w.
+ * values are stored bitshifted for performance and added precision.
+ */
+inline struct mean_and_variance_ewm mean_and_variance_ewm_update(struct mean_and_variance_ewm s1, s64 v1)
+{
+ struct mean_and_variance_ewm s2;
+ s64 m = s1.mean;
+ u64 var = s1.variance;
+ u8 w = s2.w = s1.w;
+ s64 v2 = v1 << w;
+ s64 d1 = (v2 - m);
+ s64 d2 = fast_divpow2(d1, w);
+ u64 d3 = (d1*d1) >> w;
+
+ if (!s1.init) {
+ s2.mean = v2;
+ s2.variance = 0;
+ } else {
+ s2.mean = m + d2;
+ s2.variance = var + ((d3 - (d3 >> w) - var) >> w);
+ }
+ s2.init = true;
+
+ #ifdef CONFIG_STATS_UNIT_TEST
+ printk(KERN_DEBUG "v1 = %lld, v2 = %lld, d1 = %lld, d2 = %lld, d3 = %llu, m = %lld, var = %llu",
+ v1, v2, d1, d2, d3, s2.mean, s2.variance);
+ #endif
+ return s2;
+}
+EXPORT_SYMBOL_GPL(mean_and_variance_ewm_update);
+
+/**
+ * get_ewm_mean() - get mean from @s
+ */
+inline s64 get_ewm_mean(struct mean_and_variance_ewm s)
+{
+ return fast_divpow2(s.mean, s.w);
+}
+EXPORT_SYMBOL_GPL(get_ewm_mean);
+
+/**
+ * get_ewm_variance() -- get variance from @s
+ */
+inline u64 get_ewm_variance(struct mean_and_variance_ewm s)
+{
+ // always positive don't need fast divpow2
+ return s.variance >> s.w;
+}
+EXPORT_SYMBOL_GPL(get_ewm_variance);
+
+/**
+ * get_ewm_stddev() - get standard deviation from @s
+ */
+inline u32 get_ewm_stddev(struct mean_and_variance_ewm s)
+{
+ return int_sqrt64(get_ewm_variance(s));
+}
+EXPORT_SYMBOL_GPL(get_ewm_stddev);
+
+MODULE_AUTHOR("Daniel B. Hill");
+MODULE_LICENSE("GPL");
diff --git a/lib/math/mean_and_variance_test.c b/lib/math/mean_and_variance_test.c
new file mode 100644
index 000000000000..ec32dad2246f
--- /dev/null
+++ b/lib/math/mean_and_variance_test.c
@@ -0,0 +1,161 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <kunit/test.h>
+#include <linux/stats.h>
+
+#define MAX_SQR (SQRT_U64_MAX*SQRT_U64_MAX)
+
+static void stats_basic_test(struct kunit *test)
+{
+ struct mean_and_variance s = {};
+
+ s = stats_update(s, 2);
+ s = stats_update(s, 2);
+
+ KUNIT_EXPECT_EQ(test, stats_mean(s), 2);
+ KUNIT_EXPECT_EQ(test, stats_variance(s), 0);
+ KUNIT_EXPECT_EQ(test, s.n, 2);
+
+ s = stats_update(s, 4);
+ s = stats_update(s, 4);
+
+ KUNIT_EXPECT_EQ(test, stats_mean(s), 3);
+ KUNIT_EXPECT_EQ(test, stats_variance(s), 1);
+ KUNIT_EXPECT_EQ(test, s.n, 4);
+
+ /*
+ * Test overflow bounds
+ */
+ s = (struct mean_and_variance){};
+
+ s = stats_update(s, SQRT_U64_MAX);
+
+ KUNIT_EXPECT_EQ_MSG(test,
+ s.sum_squares,
+ MAX_SQR,
+ "%llu == %llu, sqrt: %llu == %llu",
+ s.sum_squares,
+ MAX_SQR,
+ int_sqrt64(s.sum_squares),
+ SQRT_U64_MAX);
+
+ s = (struct mean_and_variance){};
+
+ s = stats_update(s, -(s64)SQRT_U64_MAX);
+
+ KUNIT_EXPECT_EQ_MSG(test,
+ s.sum_squares,
+ MAX_SQR,
+ "%llu == %llu, sqrt: %llu == %llu",
+ s.sum_squares,
+ MAX_SQR,
+ int_sqrt64(s.sum_squares),
+ SQRT_U64_MAX);
+
+ s = (struct mean_and_variance){};
+
+ s = stats_update(s, (SQRT_U64_MAX + 1));
+
+ KUNIT_EXPECT_LT(test, s.sum_squares, MAX_SQR);
+
+ s = (struct mean_and_variance){};
+
+ s = stats_update(s, (-(s64)SQRT_U64_MAX) - 1);
+
+ KUNIT_EXPECT_LT(test, s.sum_squares, MAX_SQR);
+}
+
+/*
+** Test values computed using a spreadsheet from the psuedocode at the bottom:
+** https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf
+ */
+
+static void stats_ewm_test(struct kunit *test)
+{
+ struct mean_and_variance_ewm s = {};
+
+ s.w = 2;
+
+ s = stats_ewm_update(s, 10);
+ KUNIT_EXPECT_EQ(test, stats_ewm_mean(s), 10);
+ KUNIT_EXPECT_EQ(test, stats_ewm_variance(s), 0);
+
+ s = stats_ewm_update(s, 20);
+ KUNIT_EXPECT_EQ(test, stats_ewm_mean(s), 12);
+ KUNIT_EXPECT_EQ(test, stats_ewm_variance(s), 18);
+
+ s = stats_ewm_update(s, 30);
+ KUNIT_EXPECT_EQ(test, stats_ewm_mean(s), 16);
+ KUNIT_EXPECT_EQ(test, stats_ewm_variance(s), 71);
+
+ s = (struct mean_and_variance_ewm){};
+ s.w = 2;
+
+ s = stats_ewm_update(s, -10);
+ KUNIT_EXPECT_EQ(test, stats_ewm_mean(s), -10);
+ KUNIT_EXPECT_EQ(test, stats_ewm_variance(s), 0);
+
+ s = stats_ewm_update(s, -20);
+ KUNIT_EXPECT_EQ(test, stats_ewm_mean(s), -12);
+ KUNIT_EXPECT_EQ(test, stats_ewm_variance(s), 18);
+
+ s = stats_ewm_update(s, -30);
+ KUNIT_EXPECT_EQ(test, stats_ewm_mean(s), -16);
+ KUNIT_EXPECT_EQ(test, stats_ewm_variance(s), 71);
+
+}
+
+static void stats_ewm_advanced_test(struct kunit *test)
+{
+ struct mean_and_variance_ewm s = {};
+ s64 i;
+
+ s.w = 8;
+ for (i = 10; i <= 100; i += 10)
+ s = stats_ewm_update(s, i);
+
+ KUNIT_EXPECT_EQ(test, stats_ewm_mean(s), 11);
+ KUNIT_EXPECT_EQ(test, stats_ewm_variance(s), 107);
+
+ s = (struct mean_and_variance_ewm){};
+
+ s.w = 8;
+ for (i = -10; i >= -100; i -= 10)
+ s = stats_ewm_update(s, i);
+
+ KUNIT_EXPECT_EQ(test, stats_ewm_mean(s), -11);
+ KUNIT_EXPECT_EQ(test, stats_ewm_variance(s), 107);
+
+}
+
+static void stats_fast_divpow2(struct kunit *test)
+{
+ s64 i;
+ u8 d;
+
+ for (i = 0; i < 100; i++) {
+ d = 0;
+ KUNIT_EXPECT_EQ(test, fast_divpow2(i, d), div_u64(i, 1LLU << d));
+ KUNIT_EXPECT_EQ(test, abs(fast_divpow2(-i, d)), div_u64(i, 1LLU << d));
+ for (d = 1; d < 32; d++) {
+ KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(i, d)),
+ div_u64(i, 1 << d), "%lld %u", i, d);
+ KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(-i, d)),
+ div_u64(i, 1 << d), "%lld %u", -i, d);
+ }
+ }
+}
+
+static struct kunit_case stats_test_cases[] = {
+ KUNIT_CASE(stats_basic_test),
+ KUNIT_CASE(stats_ewm_test),
+ KUNIT_CASE(stats_ewm_advanced_test),
+ KUNIT_CASE(stats_fast_divpow2),
+ {}
+};
+
+static struct kunit_suite stats_test_suite = {
+.name = "statistics",
+.test_cases = stats_test_cases
+};
+
+kunit_test_suite(stats_test_suite);