diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c index 9a4c8ff32d9d..5bf7822c53f1 100644 --- a/drivers/net/wireguard/allowedips.c +++ b/drivers/net/wireguard/allowedips.c @@ -6,6 +6,8 @@ #include "allowedips.h" #include "peer.h" +enum { MAX_ALLOWEDIPS_BITS = 128 }; + static struct kmem_cache *node_cache; static void swap_endian(u8 *dst, const u8 *src, u8 bits) @@ -40,7 +42,8 @@ static void push_rcu(struct allowedips_node **stack, struct allowedips_node __rcu *p, unsigned int *len) { if (rcu_access_pointer(p)) { - WARN_ON(IS_ENABLED(DEBUG) && *len >= 128); + if (WARN_ON(IS_ENABLED(DEBUG) && *len >= MAX_ALLOWEDIPS_BITS)) + return; stack[(*len)++] = rcu_dereference_raw(p); } } @@ -52,7 +55,7 @@ static void node_free_rcu(struct rcu_head *rcu) static void root_free_rcu(struct rcu_head *rcu) { - struct allowedips_node *node, *stack[128] = { + struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_BITS] = { container_of(rcu, struct allowedips_node, rcu) }; unsigned int len = 1; @@ -65,7 +68,7 @@ static void root_free_rcu(struct rcu_head *rcu) static void root_remove_peer_lists(struct allowedips_node *root) { - struct allowedips_node *node, *stack[128] = { root }; + struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_BITS] = { root }; unsigned int len = 1; while (len > 0 && (node = stack[--len])) { diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c index e173204ae7d7..41db10f9be49 100644 --- a/drivers/net/wireguard/selftest/allowedips.c +++ b/drivers/net/wireguard/selftest/allowedips.c @@ -593,10 +593,10 @@ bool __init wg_allowedips_selftest(void) wg_allowedips_remove_by_peer(&t, a, &mutex); test_negative(4, a, 192, 168, 0, 1); - /* These will hit the WARN_ON(len >= 128) in free_node if something - * goes wrong. + /* These will hit the WARN_ON(len >= MAX_ALLOWEDIPS_BITS) in free_node + * if something goes wrong. */ - for (i = 0; i < 128; ++i) { + for (i = 0; i < MAX_ALLOWEDIPS_BITS; ++i) { part = cpu_to_be64(~(1LLU << (i % 64))); memset(&ip, 0xff, 16); memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);