diff options
Diffstat (limited to 'net/l2tp/l2tp_core.c')
-rw-r--r-- | net/l2tp/l2tp_core.c | 204 |
1 files changed, 73 insertions, 131 deletions
diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c index b0c2d4ae781d..115918ad8eca 100644 --- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -100,8 +100,6 @@ struct l2tp_skb_cb { #define L2TP_SKB_CB(skb) ((struct l2tp_skb_cb *) &skb->cb[sizeof(struct inet_skb_parm)]) -static atomic_t l2tp_tunnel_count; -static atomic_t l2tp_session_count; static struct workqueue_struct *l2tp_wq; /* per-net private data for this module */ @@ -113,7 +111,6 @@ struct l2tp_net { spinlock_t l2tp_session_hlist_lock; }; -static void l2tp_tunnel_free(struct l2tp_tunnel *tunnel); static inline struct l2tp_tunnel *l2tp_tunnel(struct sock *sk) { @@ -127,39 +124,6 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net) return net_generic(net, l2tp_net_id); } -/* Tunnel reference counts. Incremented per session that is added to - * the tunnel. - */ -static inline void l2tp_tunnel_inc_refcount_1(struct l2tp_tunnel *tunnel) -{ - refcount_inc(&tunnel->ref_count); -} - -static inline void l2tp_tunnel_dec_refcount_1(struct l2tp_tunnel *tunnel) -{ - if (refcount_dec_and_test(&tunnel->ref_count)) - l2tp_tunnel_free(tunnel); -} -#ifdef L2TP_REFCNT_DEBUG -#define l2tp_tunnel_inc_refcount(_t) \ -do { \ - pr_debug("l2tp_tunnel_inc_refcount: %s:%d %s: cnt=%d\n", \ - __func__, __LINE__, (_t)->name, \ - refcount_read(&_t->ref_count)); \ - l2tp_tunnel_inc_refcount_1(_t); \ -} while (0) -#define l2tp_tunnel_dec_refcount(_t) \ -do { \ - pr_debug("l2tp_tunnel_dec_refcount: %s:%d %s: cnt=%d\n", \ - __func__, __LINE__, (_t)->name, \ - refcount_read(&_t->ref_count)); \ - l2tp_tunnel_dec_refcount_1(_t); \ -} while (0) -#else -#define l2tp_tunnel_inc_refcount(t) l2tp_tunnel_inc_refcount_1(t) -#define l2tp_tunnel_dec_refcount(t) l2tp_tunnel_dec_refcount_1(t) -#endif - /* Session hash global list for L2TPv3. * The session_id SHOULD be random according to RFC3931, but several * L2TP implementations use incrementing session_ids. So we do a real @@ -229,12 +193,31 @@ l2tp_session_id_hash(struct l2tp_tunnel *tunnel, u32 session_id) return &tunnel->session_hlist[hash_32(session_id, L2TP_HASH_BITS)]; } -/* Lookup a session. A new reference is held on the returned session. - * Optionally calls session->ref() too if do_ref is true. - */ +/* Lookup a tunnel. A new reference is held on the returned tunnel. */ +struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id) +{ + const struct l2tp_net *pn = l2tp_pernet(net); + struct l2tp_tunnel *tunnel; + + rcu_read_lock_bh(); + list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { + if (tunnel->tunnel_id == tunnel_id) { + l2tp_tunnel_inc_refcount(tunnel); + rcu_read_unlock_bh(); + + return tunnel; + } + } + rcu_read_unlock_bh(); + + return NULL; +} +EXPORT_SYMBOL_GPL(l2tp_tunnel_get); + +/* Lookup a session. A new reference is held on the returned session. */ struct l2tp_session *l2tp_session_get(const struct net *net, struct l2tp_tunnel *tunnel, - u32 session_id, bool do_ref) + u32 session_id) { struct hlist_head *session_list; struct l2tp_session *session; @@ -248,8 +231,6 @@ struct l2tp_session *l2tp_session_get(const struct net *net, hlist_for_each_entry_rcu(session, session_list, global_hlist) { if (session->session_id == session_id) { l2tp_session_inc_refcount(session); - if (do_ref && session->ref) - session->ref(session); rcu_read_unlock_bh(); return session; @@ -265,8 +246,6 @@ struct l2tp_session *l2tp_session_get(const struct net *net, hlist_for_each_entry(session, session_list, hlist) { if (session->session_id == session_id) { l2tp_session_inc_refcount(session); - if (do_ref && session->ref) - session->ref(session); read_unlock_bh(&tunnel->hlist_lock); return session; @@ -278,8 +257,7 @@ struct l2tp_session *l2tp_session_get(const struct net *net, } EXPORT_SYMBOL_GPL(l2tp_session_get); -struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth, - bool do_ref) +struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth) { int hash; struct l2tp_session *session; @@ -290,8 +268,6 @@ struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth, hlist_for_each_entry(session, &tunnel->session_hlist[hash], hlist) { if (++count > nth) { l2tp_session_inc_refcount(session); - if (do_ref && session->ref) - session->ref(session); read_unlock_bh(&tunnel->hlist_lock); return session; } @@ -308,8 +284,7 @@ EXPORT_SYMBOL_GPL(l2tp_session_get_nth); * This is very inefficient but is only used by management interfaces. */ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, - const char *ifname, - bool do_ref) + const char *ifname) { struct l2tp_net *pn = l2tp_pernet(net); int hash; @@ -320,8 +295,6 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) { if (!strcmp(session->ifname, ifname)) { l2tp_session_inc_refcount(session); - if (do_ref && session->ref) - session->ref(session); rcu_read_unlock_bh(); return session; @@ -335,20 +308,28 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, } EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname); -static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel, - struct l2tp_session *session) +int l2tp_session_register(struct l2tp_session *session, + struct l2tp_tunnel *tunnel) { struct l2tp_session *session_walk; struct hlist_head *g_head; struct hlist_head *head; struct l2tp_net *pn; + int err; head = l2tp_session_id_hash(tunnel, session->session_id); write_lock_bh(&tunnel->hlist_lock); + if (!tunnel->acpt_newsess) { + err = -ENODEV; + goto err_tlock; + } + hlist_for_each_entry(session_walk, head, hlist) - if (session_walk->session_id == session->session_id) - goto exist; + if (session_walk->session_id == session->session_id) { + err = -EEXIST; + goto err_tlock; + } if (tunnel->version == L2TP_HDR_VER_3) { pn = l2tp_pernet(tunnel->l2tp_net); @@ -356,12 +337,21 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel, session->session_id); spin_lock_bh(&pn->l2tp_session_hlist_lock); + hlist_for_each_entry(session_walk, g_head, global_hlist) - if (session_walk->session_id == session->session_id) - goto exist_glob; + if (session_walk->session_id == session->session_id) { + err = -EEXIST; + goto err_tlock_pnlock; + } + l2tp_tunnel_inc_refcount(tunnel); + sock_hold(tunnel->sock); hlist_add_head_rcu(&session->global_hlist, g_head); + spin_unlock_bh(&pn->l2tp_session_hlist_lock); + } else { + l2tp_tunnel_inc_refcount(tunnel); + sock_hold(tunnel->sock); } hlist_add_head(&session->hlist, head); @@ -369,13 +359,14 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel, return 0; -exist_glob: +err_tlock_pnlock: spin_unlock_bh(&pn->l2tp_session_hlist_lock); -exist: +err_tlock: write_unlock_bh(&tunnel->hlist_lock); - return -EEXIST; + return err; } +EXPORT_SYMBOL_GPL(l2tp_session_register); /* Lookup a tunnel by id */ @@ -480,9 +471,6 @@ static void l2tp_recv_dequeue_skb(struct l2tp_session *session, struct sk_buff * (*session->recv_skb)(session, skb, L2TP_SKB_CB(skb)->length); else kfree_skb(skb); - - if (session->deref) - (*session->deref)(session); } /* Dequeue skbs from the session's reorder_q, subject to packet order. @@ -511,8 +499,6 @@ start: session->reorder_skip = 1; __skb_unlink(skb, &session->reorder_q); kfree_skb(skb); - if (session->deref) - (*session->deref)(session); continue; } @@ -685,9 +671,6 @@ discard: * a data (not control) frame before coming here. Fields up to the * session-id have already been parsed and ptr points to the data * after the session-id. - * - * session->ref() must have been called prior to l2tp_recv_common(). - * session->deref() will be called automatically after skb is processed. */ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, unsigned char *ptr, unsigned char *optr, u16 hdrflags, @@ -854,9 +837,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, discard: atomic_long_inc(&session->stats.rx_errors); kfree_skb(skb); - - if (session->deref) - (*session->deref)(session); } EXPORT_SYMBOL(l2tp_recv_common); @@ -870,8 +850,6 @@ int l2tp_session_queue_purge(struct l2tp_session *session) while ((skb = skb_dequeue(&session->reorder_q))) { atomic_long_inc(&session->stats.rx_errors); kfree_skb(skb); - if (session->deref) - (*session->deref)(session); } return 0; } @@ -963,13 +941,10 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb, } /* Find the session context */ - session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true); + session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id); if (!session || !session->recv_skb) { - if (session) { - if (session->deref) - session->deref(session); + if (session) l2tp_session_dec_refcount(session); - } /* Not found? Pass to userspace to deal with */ l2tp_info(tunnel, L2TP_MSG_DATA, @@ -1264,16 +1239,14 @@ static void l2tp_tunnel_destruct(struct sock *sk) /* Remove hooks into tunnel socket */ sk->sk_destruct = tunnel->old_sk_destruct; sk->sk_user_data = NULL; - tunnel->sock = NULL; /* Remove the tunnel struct from the tunnel list */ pn = l2tp_pernet(tunnel->l2tp_net); spin_lock_bh(&pn->l2tp_tunnel_list_lock); list_del_rcu(&tunnel->list); spin_unlock_bh(&pn->l2tp_tunnel_list_lock); - atomic_dec(&l2tp_tunnel_count); - l2tp_tunnel_closeall(tunnel); + tunnel->sock = NULL; l2tp_tunnel_dec_refcount(tunnel); /* Call the original destructor */ @@ -1298,6 +1271,7 @@ void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel) tunnel->name); write_lock_bh(&tunnel->hlist_lock); + tunnel->acpt_newsess = false; for (hash = 0; hash < L2TP_HASH_SIZE; hash++) { again: hlist_for_each_safe(walk, tmp, &tunnel->session_hlist[hash]) { @@ -1308,8 +1282,8 @@ again: hlist_del_init(&session->hlist); - if (session->ref != NULL) - (*session->ref)(session); + if (test_and_set_bit(0, &session->dead)) + goto again; write_unlock_bh(&tunnel->hlist_lock); @@ -1319,9 +1293,6 @@ again: if (session->session_close != NULL) (*session->session_close)(session); - if (session->deref != NULL) - (*session->deref)(session); - l2tp_session_dec_refcount(session); write_lock_bh(&tunnel->hlist_lock); @@ -1348,17 +1319,6 @@ static void l2tp_udp_encap_destroy(struct sock *sk) } } -/* Really kill the tunnel. - * Come here only when all sessions have been cleared from the tunnel. - */ -static void l2tp_tunnel_free(struct l2tp_tunnel *tunnel) -{ - BUG_ON(refcount_read(&tunnel->ref_count) != 0); - BUG_ON(tunnel->sock != NULL); - l2tp_info(tunnel, L2TP_MSG_CONTROL, "%s: free...\n", tunnel->name); - kfree_rcu(tunnel, rcu); -} - /* Workqueue tunnel deletion function */ static void l2tp_tunnel_del_work(struct work_struct *work) { @@ -1605,6 +1565,7 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 tunnel->magic = L2TP_TUNNEL_MAGIC; sprintf(&tunnel->name[0], "tunl %u", tunnel_id); rwlock_init(&tunnel->hlist_lock); + tunnel->acpt_newsess = true; /* The net we belong to */ tunnel->l2tp_net = net; @@ -1662,7 +1623,6 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 /* Add tunnel to our list */ INIT_LIST_HEAD(&tunnel->list); - atomic_inc(&l2tp_tunnel_count); /* Bump the reference count. The tunnel context is deleted * only when this drops to zero. Must be done before list insertion @@ -1689,14 +1649,12 @@ EXPORT_SYMBOL_GPL(l2tp_tunnel_create); /* This function is used by the netlink TUNNEL_DELETE command. */ -int l2tp_tunnel_delete(struct l2tp_tunnel *tunnel) +void l2tp_tunnel_delete(struct l2tp_tunnel *tunnel) { - l2tp_tunnel_inc_refcount(tunnel); - if (false == queue_work(l2tp_wq, &tunnel->del_work)) { - l2tp_tunnel_dec_refcount(tunnel); - return 1; + if (!test_and_set_bit(0, &tunnel->dead)) { + l2tp_tunnel_inc_refcount(tunnel); + queue_work(l2tp_wq, &tunnel->del_work); } - return 0; } EXPORT_SYMBOL_GPL(l2tp_tunnel_delete); @@ -1710,8 +1668,6 @@ void l2tp_session_free(struct l2tp_session *session) if (tunnel) { BUG_ON(tunnel->magic != L2TP_TUNNEL_MAGIC); - if (session->session_id != 0) - atomic_dec(&l2tp_session_count); sock_put(tunnel->sock); session->tunnel = NULL; l2tp_tunnel_dec_refcount(tunnel); @@ -1754,15 +1710,16 @@ EXPORT_SYMBOL_GPL(__l2tp_session_unhash); */ int l2tp_session_delete(struct l2tp_session *session) { - if (session->ref) - (*session->ref)(session); + if (test_and_set_bit(0, &session->dead)) + return 0; + __l2tp_session_unhash(session); l2tp_session_queue_purge(session); if (session->session_close != NULL) (*session->session_close)(session); - if (session->deref) - (*session->deref)(session); + l2tp_session_dec_refcount(session); + return 0; } EXPORT_SYMBOL_GPL(l2tp_session_delete); @@ -1788,7 +1745,6 @@ EXPORT_SYMBOL_GPL(l2tp_session_set_header_len); struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunnel, u32 session_id, u32 peer_session_id, struct l2tp_session_cfg *cfg) { struct l2tp_session *session; - int err; session = kzalloc(sizeof(struct l2tp_session) + priv_size, GFP_KERNEL); if (session != NULL) { @@ -1844,25 +1800,7 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn l2tp_session_set_header_len(session, tunnel->version); - err = l2tp_session_add_to_tunnel(tunnel, session); - if (err) { - kfree(session); - - return ERR_PTR(err); - } - - /* Bump the reference count. The session context is deleted - * only when this drops to zero. - */ refcount_set(&session->ref_count, 1); - l2tp_tunnel_inc_refcount(tunnel); - - /* Ensure tunnel socket isn't deleted */ - sock_hold(tunnel->sock); - - /* Ignore management session in session count value */ - if (session->session_id != 0) - atomic_inc(&l2tp_session_count); return session; } @@ -1895,15 +1833,19 @@ static __net_exit void l2tp_exit_net(struct net *net) { struct l2tp_net *pn = l2tp_pernet(net); struct l2tp_tunnel *tunnel = NULL; + int hash; rcu_read_lock_bh(); list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { - (void)l2tp_tunnel_delete(tunnel); + l2tp_tunnel_delete(tunnel); } rcu_read_unlock_bh(); flush_workqueue(l2tp_wq); rcu_barrier(); + + for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) + WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash])); } static struct pernet_operations l2tp_net_ops = { |