Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r-- | net/tls/tls_main.c | 192
1 file changed, 132 insertions(+), 60 deletions(-)
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 311cec8e533d..17e8667917aa 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -55,10 +55,14 @@ enum {
 static struct proto *saved_tcpv6_prot;
 static DEFINE_MUTEX(tcpv6_prot_mutex);
+static struct proto *saved_tcpv4_prot;
+static DEFINE_MUTEX(tcpv4_prot_mutex);
 static LIST_HEAD(device_list);
-static DEFINE_MUTEX(device_mutex);
+static DEFINE_SPINLOCK(device_spinlock);
 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
 static struct proto_ops tls_sw_proto_ops;
+static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
+                         struct proto *base);
 
 static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
 {
@@ -205,23 +209,9 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
         return tls_push_sg(sk, ctx, sg, offset, flags);
 }
 
-int tls_push_pending_closed_record(struct sock *sk,
-                                   struct tls_context *tls_ctx,
-                                   int flags, long *timeo)
-{
-        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-
-        if (tls_is_partially_sent_record(tls_ctx) ||
-            !list_empty(&ctx->tx_list))
-                return tls_tx_records(sk, flags);
-        else
-                return tls_ctx->push_pending_record(sk, flags);
-}
-
 static void tls_write_space(struct sock *sk)
 {
         struct tls_context *ctx = tls_get_ctx(sk);
-        struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
 
         /* If in_tcp_sendpages call lower protocol write space handler
          * to ensure we wake up any waiting operations there. For example
@@ -232,14 +222,12 @@ static void tls_write_space(struct sock *sk)
                 return;
         }
 
-        /* Schedule the transmission if tx list is ready */
-        if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) {
-                /* Schedule the transmission */
-                if (!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
-                        schedule_delayed_work(&tx_ctx->tx_work.work, 0);
-        }
-
-        ctx->sk_write_space(sk);
+#ifdef CONFIG_TLS_DEVICE
+        if (ctx->tx_conf == TLS_HW)
+                tls_device_write_space(sk, ctx);
+        else
+#endif
+                tls_sw_write_space(sk, ctx);
 }
 
 static void tls_ctx_free(struct tls_context *ctx)
@@ -262,8 +250,10 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
         lock_sock(sk);
         sk_proto_close = ctx->sk_proto_close;
 
-        if ((ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) ||
-            (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)) {
+        if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD)
+                goto skip_tx_cleanup;
+
+        if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) {
                 free_ctx = true;
                 goto skip_tx_cleanup;
         }
@@ -366,6 +356,30 @@ static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
                         rc = -EFAULT;
                 break;
         }
+        case TLS_CIPHER_AES_GCM_256: {
+                struct tls12_crypto_info_aes_gcm_256 *crypto_info_aes_gcm_256 =
+                        container_of(crypto_info,
+                                     struct tls12_crypto_info_aes_gcm_256,
+                                     info);
+
+                if (len != sizeof(*crypto_info_aes_gcm_256)) {
+                        rc = -EINVAL;
+                        goto out;
+                }
+                lock_sock(sk);
+                memcpy(crypto_info_aes_gcm_256->iv,
+                       ctx->tx.iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
+                       TLS_CIPHER_AES_GCM_256_IV_SIZE);
+                memcpy(crypto_info_aes_gcm_256->rec_seq, ctx->tx.rec_seq,
+                       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
+                release_sock(sk);
+                if (copy_to_user(optval,
+                                 crypto_info_aes_gcm_256,
+                                 sizeof(*crypto_info_aes_gcm_256)))
+                        rc = -EFAULT;
+                break;
+        }
         default:
                 rc = -EINVAL;
         }
@@ -405,7 +419,9 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
                                   unsigned int optlen, int tx)
 {
         struct tls_crypto_info *crypto_info;
+        struct tls_crypto_info *alt_crypto_info;
         struct tls_context *ctx = tls_get_ctx(sk);
+        size_t optsize;
         int rc = 0;
         int conf;
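The TLS_CIPHER_AES_GCM_256 hunks above are the user-visible part of this change. A minimal userspace sketch of how they would be exercised, assuming the uapi constants and struct layout from <linux/tls.h>; the fallback #defines and the helper name enable_ktls_tx are illustrative, not part of the patch:

#include <linux/tls.h>
#include <netinet/tcp.h>
#include <string.h>
#include <sys/socket.h>

#ifndef SOL_TLS
#define SOL_TLS 282	/* value from include/linux/socket.h */
#endif
#ifndef TCP_ULP
#define TCP_ULP 31
#endif

int enable_ktls_tx(int sock, const unsigned char key[32],
                   const unsigned char iv[8], const unsigned char salt[4],
                   const unsigned char rec_seq[8])
{
        struct tls12_crypto_info_aes_gcm_256 ci;

        /* Attach the "tls" ULP first; SOL_TLS options fail without it. */
        if (setsockopt(sock, SOL_TCP, TCP_ULP, "tls", sizeof("tls")))
                return -1;

        memset(&ci, 0, sizeof(ci));
        ci.info.version = TLS_1_2_VERSION;
        ci.info.cipher_type = TLS_CIPHER_AES_GCM_256;
        memcpy(ci.key, key, TLS_CIPHER_AES_GCM_256_KEY_SIZE);
        memcpy(ci.iv, iv, TLS_CIPHER_AES_GCM_256_IV_SIZE);
        memcpy(ci.salt, salt, TLS_CIPHER_AES_GCM_256_SALT_SIZE);
        memcpy(ci.rec_seq, rec_seq, TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);

        /* Lands in do_tls_setsockopt_conf() with tx == 1; the matching
         * getsockopt(sock, SOL_TLS, TLS_TX, ...) is served by the
         * do_tls_getsockopt_tx() case added above.
         */
        return setsockopt(sock, SOL_TLS, TLS_TX, &ci, sizeof(ci));
}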
@@ -414,10 +430,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
                 goto out;
         }
 
-        if (tx)
+        if (tx) {
                 crypto_info = &ctx->crypto_send.info;
-        else
+                alt_crypto_info = &ctx->crypto_recv.info;
+        } else {
                 crypto_info = &ctx->crypto_recv.info;
+                alt_crypto_info = &ctx->crypto_send.info;
+        }
 
         /* Currently we don't support set crypto info more than one time */
         if (TLS_CRYPTO_INFO_READY(crypto_info)) {
@@ -432,14 +451,28 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
         }
 
         /* check version */
-        if (crypto_info->version != TLS_1_2_VERSION) {
+        if (crypto_info->version != TLS_1_2_VERSION &&
+            crypto_info->version != TLS_1_3_VERSION) {
                 rc = -ENOTSUPP;
                 goto err_crypto_info;
         }
 
+        /* Ensure that TLS version and ciphers are same in both directions */
+        if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
+                if (alt_crypto_info->version != crypto_info->version ||
+                    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
+                        rc = -EINVAL;
+                        goto err_crypto_info;
+                }
+        }
+
         switch (crypto_info->cipher_type) {
-        case TLS_CIPHER_AES_GCM_128: {
-                if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) {
+        case TLS_CIPHER_AES_GCM_128:
+        case TLS_CIPHER_AES_GCM_256: {
+                optsize = crypto_info->cipher_type == TLS_CIPHER_AES_GCM_128 ?
+                        sizeof(struct tls12_crypto_info_aes_gcm_128) :
+                        sizeof(struct tls12_crypto_info_aes_gcm_256);
+                if (optlen != optsize) {
                         rc = -EINVAL;
                         goto err_crypto_info;
                 }
@@ -538,39 +571,84 @@ static struct tls_context *create_ctx(struct sock *sk)
         struct inet_connection_sock *icsk = inet_csk(sk);
         struct tls_context *ctx;
 
-        ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
+        ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
         if (!ctx)
                 return NULL;
 
         icsk->icsk_ulp_data = ctx;
+        ctx->setsockopt = sk->sk_prot->setsockopt;
+        ctx->getsockopt = sk->sk_prot->getsockopt;
+        ctx->sk_proto_close = sk->sk_prot->close;
         return ctx;
 }
 
+static void tls_build_proto(struct sock *sk)
+{
+        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
+
+        /* Build IPv6 TLS whenever the address of tcpv6_prot changes */
+        if (ip_ver == TLSV6 &&
+            unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
+                mutex_lock(&tcpv6_prot_mutex);
+                if (likely(sk->sk_prot != saved_tcpv6_prot)) {
+                        build_protos(tls_prots[TLSV6], sk->sk_prot);
+                        smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
+                }
+                mutex_unlock(&tcpv6_prot_mutex);
+        }
+
+        if (ip_ver == TLSV4 &&
+            unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) {
+                mutex_lock(&tcpv4_prot_mutex);
+                if (likely(sk->sk_prot != saved_tcpv4_prot)) {
+                        build_protos(tls_prots[TLSV4], sk->sk_prot);
+                        smp_store_release(&saved_tcpv4_prot, sk->sk_prot);
+                }
+                mutex_unlock(&tcpv4_prot_mutex);
+        }
+}
+
+static void tls_hw_sk_destruct(struct sock *sk)
+{
+        struct tls_context *ctx = tls_get_ctx(sk);
+        struct inet_connection_sock *icsk = inet_csk(sk);
+
+        ctx->sk_destruct(sk);
+        /* Free ctx */
+        kfree(ctx);
+        icsk->icsk_ulp_data = NULL;
+}
+
 static int tls_hw_prot(struct sock *sk)
 {
         struct tls_context *ctx;
         struct tls_device *dev;
         int rc = 0;
 
-        mutex_lock(&device_mutex);
+        spin_lock_bh(&device_spinlock);
         list_for_each_entry(dev, &device_list, dev_list) {
                 if (dev->feature && dev->feature(dev)) {
                         ctx = create_ctx(sk);
                         if (!ctx)
                                 goto out;
 
+                        spin_unlock_bh(&device_spinlock);
+                        tls_build_proto(sk);
                         ctx->hash = sk->sk_prot->hash;
                         ctx->unhash = sk->sk_prot->unhash;
                         ctx->sk_proto_close = sk->sk_prot->close;
+                        ctx->sk_destruct = sk->sk_destruct;
+                        sk->sk_destruct = tls_hw_sk_destruct;
                         ctx->rx_conf = TLS_HW_RECORD;
                         ctx->tx_conf = TLS_HW_RECORD;
                         update_sk_prot(sk, ctx);
+                        spin_lock_bh(&device_spinlock);
                         rc = 1;
                         break;
                 }
         }
 out:
-        mutex_unlock(&device_mutex);
+        spin_unlock_bh(&device_spinlock);
         return rc;
 }
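tls_build_proto() above is a double-checked rebuild: a lock-free smp_load_acquire() fast path, a re-check under the per-family mutex, and an smp_store_release() that publishes the rebuilt proto table before the cached pointer. A userspace analogue in C11 atomics, with invented names (ensure_built, saved_base, rebuild), may make the pairing clearer:

#include <pthread.h>
#include <stdatomic.h>

struct proto_table { int dummy; };	/* stand-in for tls_prots[...] */

static _Atomic(const void *) saved_base;	/* analogue of saved_tcpv6_prot */
static pthread_mutex_t rebuild_lock = PTHREAD_MUTEX_INITIALIZER;
static struct proto_table table;

static void rebuild(struct proto_table *t, const void *base)
{
        (void)t; (void)base;	/* fill in the table from @base */
}

void ensure_built(const void *base)
{
        /* Fast path: acquire-load pairs with the release-store below. */
        if (atomic_load_explicit(&saved_base, memory_order_acquire) == base)
                return;

        pthread_mutex_lock(&rebuild_lock);
        /* Re-check under the lock: another thread may have rebuilt already. */
        if (atomic_load_explicit(&saved_base, memory_order_relaxed) != base) {
                rebuild(&table, base);
                /* Publish: table contents become visible before the pointer. */
                atomic_store_explicit(&saved_base, base, memory_order_release);
        }
        pthread_mutex_unlock(&rebuild_lock);
}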
@@ -579,12 +657,17 @@ static void tls_hw_unhash(struct sock *sk)
         struct tls_context *ctx = tls_get_ctx(sk);
         struct tls_device *dev;
 
-        mutex_lock(&device_mutex);
+        spin_lock_bh(&device_spinlock);
         list_for_each_entry(dev, &device_list, dev_list) {
-                if (dev->unhash)
+                if (dev->unhash) {
+                        kref_get(&dev->kref);
+                        spin_unlock_bh(&device_spinlock);
                         dev->unhash(dev, sk);
+                        kref_put(&dev->kref, dev->release);
+                        spin_lock_bh(&device_spinlock);
+                }
         }
-        mutex_unlock(&device_mutex);
+        spin_unlock_bh(&device_spinlock);
         ctx->unhash(sk);
 }
 
@@ -595,12 +678,17 @@ static int tls_hw_hash(struct sock *sk)
         int err;
 
         err = ctx->hash(sk);
-        mutex_lock(&device_mutex);
+        spin_lock_bh(&device_spinlock);
         list_for_each_entry(dev, &device_list, dev_list) {
-                if (dev->hash)
+                if (dev->hash) {
+                        kref_get(&dev->kref);
+                        spin_unlock_bh(&device_spinlock);
                         err |= dev->hash(dev, sk);
+                        kref_put(&dev->kref, dev->release);
+                        spin_lock_bh(&device_spinlock);
+                }
         }
-        mutex_unlock(&device_mutex);
+        spin_unlock_bh(&device_spinlock);
 
         if (err)
                 tls_hw_unhash(sk);
@@ -653,7 +741,6 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 
 static int tls_init(struct sock *sk)
 {
-        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
         struct tls_context *ctx;
         int rc = 0;
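The kref_get()/unlock/callback/kref_put()/relock sequence in tls_hw_unhash() and tls_hw_hash() above is the standard way to invoke a possibly-sleeping callback found under a spinlock: the reference pins the entry while the lock is dropped. A self-contained userspace analogue (all names invented; like the kernel code, it assumes the entry stays linked while the lock is released):

#include <pthread.h>
#include <stdatomic.h>

struct dev_entry {
        struct dev_entry *next;
        atomic_int refcnt;			/* analogue of dev->kref */
        void (*callback)(struct dev_entry *);	/* may sleep */
        void (*release)(struct dev_entry *);	/* analogue of dev->release */
};

static pthread_spinlock_t list_lock;	/* pthread_spin_init() at startup */
static struct dev_entry *dev_list;

static void entry_put(struct dev_entry *e)
{
        /* Last reference frees the entry, as kref_put() would. */
        if (atomic_fetch_sub(&e->refcnt, 1) == 1)
                e->release(e);
}

void walk_and_call(void)
{
        struct dev_entry *e;

        pthread_spin_lock(&list_lock);
        for (e = dev_list; e; e = e->next) {
                if (!e->callback)
                        continue;
                atomic_fetch_add(&e->refcnt, 1);	/* pin the entry */
                pthread_spin_unlock(&list_lock);	/* callback may sleep */
                e->callback(e);
                entry_put(e);
                pthread_spin_lock(&list_lock);	/* re-take before advancing */
        }
        pthread_spin_unlock(&list_lock);
}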
@@ -675,21 +762,8 @@ static int tls_init(struct sock *sk)
                 rc = -ENOMEM;
                 goto out;
         }
-        ctx->setsockopt = sk->sk_prot->setsockopt;
-        ctx->getsockopt = sk->sk_prot->getsockopt;
-        ctx->sk_proto_close = sk->sk_prot->close;
-
-        /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
-        if (ip_ver == TLSV6 &&
-            unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
-                mutex_lock(&tcpv6_prot_mutex);
-                if (likely(sk->sk_prot != saved_tcpv6_prot)) {
-                        build_protos(tls_prots[TLSV6], sk->sk_prot);
-                        smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
-                }
-                mutex_unlock(&tcpv6_prot_mutex);
-        }
 
+        tls_build_proto(sk);
         ctx->tx_conf = TLS_BASE;
         ctx->rx_conf = TLS_BASE;
         update_sk_prot(sk, ctx);
@@ -699,17 +773,17 @@ out:
 
 void tls_register_device(struct tls_device *device)
 {
-        mutex_lock(&device_mutex);
+        spin_lock_bh(&device_spinlock);
         list_add_tail(&device->dev_list, &device_list);
-        mutex_unlock(&device_mutex);
+        spin_unlock_bh(&device_spinlock);
 }
 EXPORT_SYMBOL(tls_register_device);
 
 void tls_unregister_device(struct tls_device *device)
 {
-        mutex_lock(&device_mutex);
+        spin_lock_bh(&device_spinlock);
         list_del(&device->dev_list);
-        mutex_unlock(&device_mutex);
+        spin_unlock_bh(&device_spinlock);
 }
 EXPORT_SYMBOL(tls_unregister_device);
 
@@ -721,8 +795,6 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
 
 static int __init tls_register(void)
 {
-        build_protos(tls_prots[TLSV4], &tcp_prot);
-
         tls_sw_proto_ops = inet_stream_ops;
         tls_sw_proto_ops.splice_read = tls_sw_splice_read;
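On the driver side, the struct tls_device interface this file consumes now relies on dev->kref and dev->release, as the hash/unhash hunks show, and tls_register_device() may now be called with the list protected only by a spinlock. A hypothetical registration sketch under those assumptions (the mydrv_* names are invented, and kref initialization is taken to be the driver's job):

#include <linux/slab.h>
#include <net/tls.h>

static void mydrv_tls_release(struct kref *kref)
{
        struct tls_device *dev = container_of(kref, struct tls_device, kref);

        /* Last reference dropped by kref_put() in tls_hw_hash()/unhash(). */
        kfree(dev);
}

static int mydrv_tls_feature(struct tls_device *dev)
{
        return 1;	/* claim TLS_HW_RECORD capability for every socket */
}

static int mydrv_register_tls(void)
{
        struct tls_device *dev = kzalloc(sizeof(*dev), GFP_KERNEL);

        if (!dev)
                return -ENOMEM;
        kref_init(&dev->kref);
        dev->feature = mydrv_tls_feature;
        dev->release = mydrv_tls_release;
        tls_register_device(dev);
        return 0;
}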