Diffstat (limited to 'net/tls/tls_device.c')
-rw-r--r--   net/tls/tls_device.c   107
1 file changed, 71 insertions(+), 36 deletions(-)
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index e3e6cf75aa03..a03d66046ca3 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -71,7 +71,13 @@ static void tls_device_tx_del_task(struct work_struct *work)
struct tls_offload_context_tx *offload_ctx =
container_of(work, struct tls_offload_context_tx, destruct_work);
struct tls_context *ctx = offload_ctx->ctx;
- struct net_device *netdev = ctx->netdev;
+ struct net_device *netdev;
+
+ /* Safe, because this is the destroy flow, refcount is 0, so
+ * tls_device_down can't store this field in parallel.
+ */
+ netdev = rcu_dereference_protected(ctx->netdev,
+ !refcount_read(&ctx->refcount));
netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX);
dev_put(netdev);
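
ctx->netdev is now annotated __rcu, so even this teardown path has to go through an RCU accessor. Because the refcount has already dropped to zero, tls_device_down() cannot be storing to the field concurrently; rcu_dereference_protected() documents that assumption and lets sparse/lockdep check it. A minimal sketch of the accessor pattern, with made-up names (demo_ctx is not a real kernel structure):

    #include <linux/netdevice.h>
    #include <linux/rcupdate.h>
    #include <linux/refcount.h>

    struct demo_ctx {
            struct net_device __rcu *netdev;
            refcount_t refcount;
            struct list_head list;
    };

    static void demo_destroy(struct demo_ctx *c)
    {
            struct net_device *dev;

            /* refcount == 0: no writer can race, so updates are excluded */
            dev = rcu_dereference_protected(c->netdev,
                                            !refcount_read(&c->refcount));
            if (dev)
                    dev_put(dev);
    }
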
@@ -81,6 +87,7 @@ static void tls_device_tx_del_task(struct work_struct *work)
static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
{
+ struct net_device *netdev;
unsigned long flags;
bool async_cleanup;
@@ -91,7 +98,14 @@ static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
}
list_del(&ctx->list); /* Remove from tls_device_list / tls_device_down_list */
- async_cleanup = ctx->netdev && ctx->tx_conf == TLS_HW;
+
+ /* Safe, because this is the destroy flow, refcount is 0, so
+ * tls_device_down can't store this field in parallel.
+ */
+ netdev = rcu_dereference_protected(ctx->netdev,
+ !refcount_read(&ctx->refcount));
+
+ async_cleanup = netdev && ctx->tx_conf == TLS_HW;
if (async_cleanup) {
struct tls_offload_context_tx *offload_ctx = tls_offload_ctx_tx(ctx);
@@ -229,7 +243,8 @@ static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
trace_tls_device_tx_resync_send(sk, seq, rcd_sn);
down_read(&device_offload_lock);
- netdev = tls_ctx->netdev;
+ netdev = rcu_dereference_protected(tls_ctx->netdev,
+ lockdep_is_held(&device_offload_lock));
if (netdev)
err = netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq,
rcd_sn,
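
In the TX resync path updates to the pointer are excluded by device_offload_lock rather than by the refcount, so the condition handed to rcu_dereference_protected() is a lockdep_is_held() check; with CONFIG_PROVE_RCU an unlocked caller produces a splat instead of a silent race. Roughly, reusing the demo_ctx sketched above (lock name hypothetical):

    static DECLARE_RWSEM(demo_offload_lock);

    /* Caller must hold demo_offload_lock (read or write side). */
    static struct net_device *demo_netdev_locked(struct demo_ctx *c)
    {
            return rcu_dereference_protected(c->netdev,
                                             lockdep_is_held(&demo_offload_lock));
    }
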
@@ -710,7 +725,7 @@ static void tls_device_resync_rx(struct tls_context *tls_ctx,
trace_tls_device_rx_resync_send(sk, seq, rcd_sn, rx_ctx->resync_type);
rcu_read_lock();
- netdev = READ_ONCE(tls_ctx->netdev);
+ netdev = rcu_dereference(tls_ctx->netdev);
if (netdev)
netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
TLS_OFFLOAD_CTX_DIR_RX);
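
On the RX resync datapath the pointer is read under rcu_read_lock(), so rcu_dereference() is the matching accessor: it performs the same dependency-ordered load that READ_ONCE() did, and in addition sparse warns if the __rcu pointer is accessed bare while lockdep complains if the caller is not in an RCU read-side critical section. A sketch of the reader side (hypothetical helper, reusing demo_ctx from above):

    static void demo_resync(struct demo_ctx *c)
    {
            struct net_device *dev;

            rcu_read_lock();
            dev = rcu_dereference(c->netdev);  /* NULL once the device is going away */
            if (dev)
                    netdev_dbg(dev, "resync requested\n");
            rcu_read_unlock();                 /* dev must not be used past this point */
    }
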
@@ -887,17 +902,28 @@ static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
}
static int
-tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
+tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx)
{
+ struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
+ const struct tls_cipher_size_desc *cipher_sz;
int err, offset, copy, data_len, pos;
struct sk_buff *skb, *skb_iter;
struct scatterlist sg[1];
struct strp_msg *rxm;
char *orig_buf, *buf;
+ switch (tls_ctx->crypto_recv.info.cipher_type) {
+ case TLS_CIPHER_AES_GCM_128:
+ case TLS_CIPHER_AES_GCM_256:
+ break;
+ default:
+ return -EINVAL;
+ }
+ cipher_sz = &tls_cipher_size_desc[tls_ctx->crypto_recv.info.cipher_type];
+
rxm = strp_msg(tls_strp_msg(sw_ctx));
- orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
- TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
+ orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + cipher_sz->iv,
+ sk->sk_allocation);
if (!orig_buf)
return -ENOMEM;
buf = orig_buf;
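
The hard-coded TLS_CIPHER_AES_GCM_128_* constants are replaced with a per-cipher size lookup so AES-GCM-256 can share this code; the switch rejecting every other cipher_type keeps the table index in bounds. tls_cipher_size_desc itself is introduced elsewhere in this series; its assumed shape, limited to the fields used in this file, is roughly:

    /* Assumed layout, for illustration only; see the definition added by the series. */
    struct tls_cipher_size_desc {
            unsigned int iv;
            unsigned int salt;
            unsigned int tag;
            unsigned int rec_seq;
    };

    static const struct tls_cipher_size_desc demo_cipher_size_desc[] = {
            [TLS_CIPHER_AES_GCM_128] = {
                    .iv      = TLS_CIPHER_AES_GCM_128_IV_SIZE,
                    .salt    = TLS_CIPHER_AES_GCM_128_SALT_SIZE,
                    .tag     = TLS_CIPHER_AES_GCM_128_TAG_SIZE,
                    .rec_seq = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE,
            },
            [TLS_CIPHER_AES_GCM_256] = {
                    .iv      = TLS_CIPHER_AES_GCM_256_IV_SIZE,
                    .salt    = TLS_CIPHER_AES_GCM_256_SALT_SIZE,
                    .tag     = TLS_CIPHER_AES_GCM_256_TAG_SIZE,
                    .rec_seq = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE,
            },
    };
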
@@ -912,10 +938,8 @@ tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
sg_init_table(sg, 1);
sg_set_buf(&sg[0], buf,
- rxm->full_len + TLS_HEADER_SIZE +
- TLS_CIPHER_AES_GCM_128_IV_SIZE);
- err = skb_copy_bits(skb, offset, buf,
- TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
+ rxm->full_len + TLS_HEADER_SIZE + cipher_sz->iv);
+ err = skb_copy_bits(skb, offset, buf, TLS_HEADER_SIZE + cipher_sz->iv);
if (err)
goto free_buf;
@@ -926,7 +950,7 @@ tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
else
err = 0;
- data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ data_len = rxm->full_len - cipher_sz->tag;
if (skb_pagelen(skb) > offset) {
copy = min_t(int, skb_pagelen(skb) - offset, data_len);
@@ -984,11 +1008,17 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
int is_decrypted = skb->decrypted;
int is_encrypted = !is_decrypted;
struct sk_buff *skb_iter;
+ int left;
+ left = rxm->full_len - skb->len;
/* Check if all the data is decrypted already */
- skb_walk_frags(skb, skb_iter) {
+ skb_iter = skb_shinfo(skb)->frag_list;
+ while (skb_iter && left > 0) {
is_decrypted &= skb_iter->decrypted;
is_encrypted &= !skb_iter->decrypted;
+
+ left -= skb_iter->len;
+ skb_iter = skb_iter->next;
}
trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
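
The open-coded loop replaces skb_walk_frags() so the walk stops once the bytes of the current record have been accounted for; fragments that already belong to the next record no longer vote on whether this record counts as decrypted or encrypted. The same bounded-walk idea as a stand-alone helper (hypothetical; 'consumed' stands for whatever part of the record the head skb covers):

    static bool demo_frags_decrypted(struct sk_buff *skb, int record_len, int consumed)
    {
            struct sk_buff *frag = skb_shinfo(skb)->frag_list;
            int left = record_len - consumed;
            bool decrypted = skb->decrypted;

            while (frag && left > 0) {
                    decrypted &= frag->decrypted;
                    left -= frag->len;
                    frag = frag->next;
            }
            return decrypted;
    }
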
@@ -1003,7 +1033,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
* likely have initial fragments decrypted, and final ones not
* decrypted. We need to reencrypt that single SKB.
*/
- return tls_device_reencrypt(sk, sw_ctx);
+ return tls_device_reencrypt(sk, tls_ctx);
}
/* Return immediately if the record is either entirely plaintext or
@@ -1020,7 +1050,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
}
ctx->resync_nh_reset = 1;
- return tls_device_reencrypt(sk, sw_ctx);
+ return tls_device_reencrypt(sk, tls_ctx);
}
static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
@@ -1029,7 +1059,7 @@ static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
if (sk->sk_destruct != tls_device_sk_destruct) {
refcount_set(&ctx->refcount, 1);
dev_hold(netdev);
- ctx->netdev = netdev;
+ RCU_INIT_POINTER(ctx->netdev, netdev);
spin_lock_irq(&tls_device_lock);
list_add_tail(&ctx->list, &tls_device_list);
spin_unlock_irq(&tls_device_lock);
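
RCU_INIT_POINTER() is enough at attach time because the context has not been added to tls_device_list yet, so no reader can observe the pointer; once the context is published, later updates (in tls_device_down() and tls_device_offload_cleanup_rx()) go through rcu_assign_pointer(), which provides the release ordering readers depend on. The general rule, sketched with the hypothetical names from above:

    static LIST_HEAD(demo_list);
    static DEFINE_SPINLOCK(demo_lock);

    static void demo_attach(struct demo_ctx *c, struct net_device *dev)
    {
            dev_hold(dev);
            /* c is not reachable by readers yet: plain init, no barrier needed */
            RCU_INIT_POINTER(c->netdev, dev);

            spin_lock(&demo_lock);
            list_add_tail(&c->list, &demo_list);    /* publication point */
            spin_unlock(&demo_lock);

            /* from here on, updates must use rcu_assign_pointer() */
    }
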
@@ -1041,9 +1071,9 @@ static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{
- u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
+ const struct tls_cipher_size_desc *cipher_sz;
struct tls_record_info *start_marker_record;
struct tls_offload_context_tx *offload_ctx;
struct tls_crypto_info *crypto_info;
@@ -1078,44 +1108,44 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
switch (crypto_info->cipher_type) {
case TLS_CIPHER_AES_GCM_128:
- nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
- tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
- iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
- rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
- salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
rec_seq =
((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
break;
+ case TLS_CIPHER_AES_GCM_256:
+ iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
+ rec_seq =
+ ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
+ break;
default:
rc = -EINVAL;
goto release_netdev;
}
+ cipher_sz = &tls_cipher_size_desc[crypto_info->cipher_type];
/* Sanity-check the rec_seq_size for stack allocations */
- if (rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
+ if (cipher_sz->rec_seq > TLS_MAX_REC_SEQ_SIZE) {
rc = -EINVAL;
goto release_netdev;
}
prot->version = crypto_info->version;
prot->cipher_type = crypto_info->cipher_type;
- prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
- prot->tag_size = tag_size;
+ prot->prepend_size = TLS_HEADER_SIZE + cipher_sz->iv;
+ prot->tag_size = cipher_sz->tag;
prot->overhead_size = prot->prepend_size + prot->tag_size;
- prot->iv_size = iv_size;
- prot->salt_size = salt_size;
- ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
- GFP_KERNEL);
+ prot->iv_size = cipher_sz->iv;
+ prot->salt_size = cipher_sz->salt;
+ ctx->tx.iv = kmalloc(cipher_sz->iv + cipher_sz->salt, GFP_KERNEL);
if (!ctx->tx.iv) {
rc = -ENOMEM;
goto release_netdev;
}
- memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+ memcpy(ctx->tx.iv + cipher_sz->salt, iv, cipher_sz->iv);
- prot->rec_seq_size = rec_seq_size;
- ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
+ prot->rec_seq_size = cipher_sz->rec_seq;
+ ctx->tx.rec_seq = kmemdup(rec_seq, cipher_sz->rec_seq, GFP_KERNEL);
if (!ctx->tx.rec_seq) {
rc = -ENOMEM;
goto free_iv;
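
With the lookup table in place, the only cipher-specific code left in tls_set_device_offload() is picking the iv and rec_seq pointers out of the matching uAPI struct; every size comes from cipher_sz. As a worked example of the arithmetic above for AES-GCM-256 (uAPI sizes: IV 8, salt 4, tag 16, rec_seq 8 bytes; TLS_HEADER_SIZE is 5):

    prot->prepend_size  = TLS_HEADER_SIZE + cipher_sz->iv  = 5 + 8   = 13
    prot->tag_size      = cipher_sz->tag                   = 16
    prot->overhead_size = prepend_size + tag_size          = 13 + 16 = 29
    ctx->tx.iv          = kmalloc(iv + salt)               = 8 + 4   = 12 bytes,
                          with the explicit IV copied in at offset salt (4)
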
@@ -1300,7 +1330,8 @@ void tls_device_offload_cleanup_rx(struct sock *sk)
struct net_device *netdev;
down_read(&device_offload_lock);
- netdev = tls_ctx->netdev;
+ netdev = rcu_dereference_protected(tls_ctx->netdev,
+ lockdep_is_held(&device_offload_lock));
if (!netdev)
goto out;
@@ -1309,7 +1340,7 @@ void tls_device_offload_cleanup_rx(struct sock *sk)
if (tls_ctx->tx_conf != TLS_HW) {
dev_put(netdev);
- tls_ctx->netdev = NULL;
+ rcu_assign_pointer(tls_ctx->netdev, NULL);
} else {
set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags);
}
@@ -1329,7 +1360,11 @@ static int tls_device_down(struct net_device *netdev)
spin_lock_irqsave(&tls_device_lock, flags);
list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
- if (ctx->netdev != netdev ||
+ struct net_device *ctx_netdev =
+ rcu_dereference_protected(ctx->netdev,
+ lockdep_is_held(&device_offload_lock));
+
+ if (ctx_netdev != netdev ||
!refcount_inc_not_zero(&ctx->refcount))
continue;
@@ -1346,7 +1381,7 @@ static int tls_device_down(struct net_device *netdev)
/* Stop the RX and TX resync.
* tls_dev_resync must not be called after tls_dev_del.
*/
- WRITE_ONCE(ctx->netdev, NULL);
+ rcu_assign_pointer(ctx->netdev, NULL);
/* Start skipping the RX resync logic completely. */
set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);
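
For completeness, this is the writer side the new accessors guard against: tls_device_down() publishes NULL with rcu_assign_pointer() and, in the existing code just past this hunk, calls synchronize_net() so that in-flight resync callers either saw the old pointer while it was still valid or see NULL. Schematically (hypothetical helper, not the full function):

    static void demo_unpublish(struct demo_ctx *c)
    {
            struct net_device *dev;

            /* caller serializes against other writers */
            dev = rcu_dereference_protected(c->netdev, 1);
            rcu_assign_pointer(c->netdev, NULL);    /* new readers observe NULL */
            synchronize_net();                      /* let existing readers drain */
            dev_put(dev);
    }
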