diff options
Diffstat (limited to 'tools/testing/selftests/bpf/test_sockmap.c')
-rw-r--r-- | tools/testing/selftests/bpf/test_sockmap.c | 145 |
1 files changed, 120 insertions, 25 deletions
diff --git a/tools/testing/selftests/bpf/test_sockmap.c b/tools/testing/selftests/bpf/test_sockmap.c index 0c7d9e556b47..10a5fa83c75a 100644 --- a/tools/testing/selftests/bpf/test_sockmap.c +++ b/tools/testing/selftests/bpf/test_sockmap.c @@ -71,6 +71,7 @@ int txmsg_start; int txmsg_end; int txmsg_ingress; int txmsg_skb; +int ktls; static const struct option long_options[] = { {"help", no_argument, NULL, 'h' }, @@ -92,6 +93,7 @@ static const struct option long_options[] = { {"txmsg_end", required_argument, NULL, 'e'}, {"txmsg_ingress", no_argument, &txmsg_ingress, 1 }, {"txmsg_skb", no_argument, &txmsg_skb, 1 }, + {"ktls", no_argument, &ktls, 1 }, {0, 0, NULL, 0 } }; @@ -112,6 +114,76 @@ static void usage(char *argv[]) printf("\n"); } +#define TCP_ULP 31 +#define TLS_TX 1 +#define TLS_RX 2 +#include <linux/tls.h> + +char *sock_to_string(int s) +{ + if (s == c1) + return "client1"; + else if (s == c2) + return "client2"; + else if (s == s1) + return "server1"; + else if (s == s2) + return "server2"; + else if (s == p1) + return "peer1"; + else if (s == p2) + return "peer2"; + else + return "unknown"; +} + +static int sockmap_init_ktls(int verbose, int s) +{ + struct tls12_crypto_info_aes_gcm_128 tls_tx = { + .info = { + .version = TLS_1_2_VERSION, + .cipher_type = TLS_CIPHER_AES_GCM_128, + }, + }; + struct tls12_crypto_info_aes_gcm_128 tls_rx = { + .info = { + .version = TLS_1_2_VERSION, + .cipher_type = TLS_CIPHER_AES_GCM_128, + }, + }; + int so_buf = 6553500; + int err; + + err = setsockopt(s, 6, TCP_ULP, "tls", sizeof("tls")); + if (err) { + fprintf(stderr, "setsockopt: TCP_ULP(%s) failed with error %i\n", sock_to_string(s), err); + return -EINVAL; + } + err = setsockopt(s, SOL_TLS, TLS_TX, (void *)&tls_tx, sizeof(tls_tx)); + if (err) { + fprintf(stderr, "setsockopt: TLS_TX(%s) failed with error %i\n", sock_to_string(s), err); + return -EINVAL; + } + err = setsockopt(s, SOL_TLS, TLS_RX, (void *)&tls_rx, sizeof(tls_rx)); + if (err) { + fprintf(stderr, "setsockopt: TLS_RX(%s) failed with error %i\n", sock_to_string(s), err); + return -EINVAL; + } + err = setsockopt(s, SOL_SOCKET, SO_SNDBUF, &so_buf, sizeof(so_buf)); + if (err) { + fprintf(stderr, "setsockopt: (%s) failed sndbuf with error %i\n", sock_to_string(s), err); + return -EINVAL; + } + err = setsockopt(s, SOL_SOCKET, SO_RCVBUF, &so_buf, sizeof(so_buf)); + if (err) { + fprintf(stderr, "setsockopt: (%s) failed rcvbuf with error %i\n", sock_to_string(s), err); + return -EINVAL; + } + + if (verbose) + fprintf(stdout, "socket(%s) kTLS enabled\n", sock_to_string(s)); + return 0; +} static int sockmap_init_sockets(int verbose) { int i, err, one = 1; @@ -456,6 +528,21 @@ static int sendmsg_test(struct sockmap_options *opt) else rx_fd = p2; + if (ktls) { + /* Redirecting into non-TLS socket which sends into a TLS + * socket is not a valid test. So in this case lets not + * enable kTLS but still run the test. + */ + if (!txmsg_redir || (txmsg_redir && txmsg_ingress)) { + err = sockmap_init_ktls(opt->verbose, rx_fd); + if (err) + return err; + } + err = sockmap_init_ktls(opt->verbose, c1); + if (err) + return err; + } + rxpid = fork(); if (rxpid == 0) { if (opt->drop_expected) @@ -469,8 +556,6 @@ static int sendmsg_test(struct sockmap_options *opt) fprintf(stderr, "msg_loop_rx: iov_count %i iov_buf %i cnt %i err %i\n", iov_count, iov_buf, cnt, err); - shutdown(p2, SHUT_RDWR); - shutdown(p1, SHUT_RDWR); if (s.end.tv_sec - s.start.tv_sec) { sent_Bps = sentBps(s); recvd_Bps = recvdBps(s); @@ -500,7 +585,6 @@ static int sendmsg_test(struct sockmap_options *opt) fprintf(stderr, "msg_loop_tx: iov_count %i iov_buf %i cnt %i err %i\n", iov_count, iov_buf, cnt, err); - shutdown(c1, SHUT_RDWR); if (s.end.tv_sec - s.start.tv_sec) { sent_Bps = sentBps(s); recvd_Bps = recvdBps(s); @@ -910,6 +994,8 @@ static void test_options(char *options) strncat(options, "ingress,", OPTSTRING); if (txmsg_skb) strncat(options, "skb,", OPTSTRING); + if (ktls) + strncat(options, "ktls,", OPTSTRING); } static int __test_exec(int cgrp, int test, struct sockmap_options *opt) @@ -1348,9 +1434,9 @@ static int populate_progs(char *bpf_file) return 0; } -static int __test_suite(char *bpf_file) +static int __test_suite(int cg_fd, char *bpf_file) { - int cg_fd, err; + int err, cleanup = cg_fd; err = populate_progs(bpf_file); if (err < 0) { @@ -1358,22 +1444,24 @@ static int __test_suite(char *bpf_file) return err; } - if (setup_cgroup_environment()) { - fprintf(stderr, "ERROR: cgroup env failed\n"); - return -EINVAL; - } - - cg_fd = create_and_get_cgroup(CG_PATH); if (cg_fd < 0) { - fprintf(stderr, - "ERROR: (%i) open cg path failed: %s\n", - cg_fd, optarg); - return cg_fd; - } + if (setup_cgroup_environment()) { + fprintf(stderr, "ERROR: cgroup env failed\n"); + return -EINVAL; + } - if (join_cgroup(CG_PATH)) { - fprintf(stderr, "ERROR: failed to join cgroup\n"); - return -EINVAL; + cg_fd = create_and_get_cgroup(CG_PATH); + if (cg_fd < 0) { + fprintf(stderr, + "ERROR: (%i) open cg path failed: %s\n", + cg_fd, optarg); + return cg_fd; + } + + if (join_cgroup(CG_PATH)) { + fprintf(stderr, "ERROR: failed to join cgroup\n"); + return -EINVAL; + } } /* Tests basic commands and APIs with range of iov values */ @@ -1394,20 +1482,24 @@ static int __test_suite(char *bpf_file) out: printf("Summary: %i PASSED %i FAILED\n", passed, failed); - cleanup_cgroup_environment(); - close(cg_fd); + if (cleanup < 0) { + cleanup_cgroup_environment(); + close(cg_fd); + } return err; } -static int test_suite(void) +static int test_suite(int cg_fd) { int err; - err = __test_suite(BPF_SOCKMAP_FILENAME); + err = __test_suite(cg_fd, BPF_SOCKMAP_FILENAME); if (err) goto out; - err = __test_suite(BPF_SOCKHASH_FILENAME); + err = __test_suite(cg_fd, BPF_SOCKHASH_FILENAME); out: + if (cg_fd > -1) + close(cg_fd); return err; } @@ -1420,7 +1512,7 @@ int main(int argc, char **argv) int test = PING_PONG; if (argc < 2) - return test_suite(); + return test_suite(-1); while ((opt = getopt_long(argc, argv, ":dhvc:r:i:l:t:", long_options, &longindex)) != -1) { @@ -1486,6 +1578,9 @@ int main(int argc, char **argv) } } + if (argc <= 3 && cg_fd) + return test_suite(cg_fd); + if (!cg_fd) { fprintf(stderr, "%s requires cgroup option: --cgroup <path>\n", argv[0]); |