* [PATCH nft] src: add refcount asserts
@ 2025-10-17 8:13 Florian Westphal
2025-10-19 20:40 ` Pablo Neira Ayuso
0 siblings, 1 reply; 3+ messages in thread
From: Florian Westphal @ 2025-10-17 8:13 UTC (permalink / raw)
To: netfilter-devel; +Cc: Florian Westphal
_get() functions must not be used when refcnt is 0, as expr_free()
releases expressions on the 1 -> 0 transition.
Also, check that a refcount would not overflow from UINT_MAX to 0.
This helps catch use-after-free refcounting bugs even when nft
is built without ASAN support.
Signed-off-by: Florian Westphal <fw@strlen.de>
---
include/rule.h | 2 +-
src/expression.c | 12 ++++++++++++
src/rule.c | 26 ++++++++++++++++++++++++++
3 files changed, 39 insertions(+), 1 deletion(-)
diff --git a/include/rule.h b/include/rule.h
index 8d2f29d09337..bcdc50cad59d 100644
--- a/include/rule.h
+++ b/include/rule.h
@@ -115,7 +115,7 @@ struct symbol {
struct list_head list;
const char *identifier;
struct expr *expr;
- int refcnt;
+ unsigned int refcnt;
};
extern void symbol_bind(struct scope *scope, const char *identifier,
diff --git a/src/expression.c b/src/expression.c
index 019c263f187b..3e74a669c8a4 100644
--- a/src/expression.c
+++ b/src/expression.c
@@ -68,6 +68,11 @@ struct expr *expr_clone(const struct expr *expr)
struct expr *expr_get(struct expr *expr)
{
+ if (expr->refcnt == 0)
+ BUG("refcnt 0, use-after-free on type %s\n", expr_name(expr));
+ if (expr->refcnt == UINT_MAX)
+ BUG("refcnt overflow for type %s\n", expr_name(expr));
+
expr->refcnt++;
return expr;
}
@@ -84,6 +89,10 @@ void expr_free(struct expr *expr)
{
if (expr == NULL)
return;
+
+ if (expr->refcnt == 0)
+ BUG("refcnt 0, possible double-free on type %p %s\n", (void *)expr, expr_name(expr));
+
if (--expr->refcnt > 0)
return;
@@ -343,11 +352,14 @@ static void variable_expr_clone(struct expr *new, const struct expr *expr)
new->scope = expr->scope;
new->sym = expr->sym;
+ assert(expr->sym->refcnt > 0);
+ assert(expr->sym->refcnt < UINT_MAX);
expr->sym->refcnt++;
}
static void variable_expr_destroy(struct expr *expr)
{
+ assert(expr->sym->refcnt > 0);
expr->sym->refcnt--;
}
diff --git a/src/rule.c b/src/rule.c
index d0a62a3ee002..722e48ae254b 100644
--- a/src/rule.c
+++ b/src/rule.c
@@ -181,6 +181,8 @@ struct set *set_clone(const struct set *set)
struct set *set_get(struct set *set)
{
+ assert(set->refcnt > 0);
+ assert(set->refcnt < UINT_MAX);
set->refcnt++;
return set;
}
@@ -189,6 +191,7 @@ void set_free(struct set *set)
{
struct stmt *stmt, *next;
+ assert(set->refcnt > 0);
if (--set->refcnt > 0)
return;
@@ -484,12 +487,15 @@ struct rule *rule_alloc(const struct location *loc, const struct handle *h)
struct rule *rule_get(struct rule *rule)
{
+ assert(rule->refcnt > 0);
+ assert(rule->refcnt < UINT_MAX);
rule->refcnt++;
return rule;
}
void rule_free(struct rule *rule)
{
+ assert(rule->refcnt > 0);
if (--rule->refcnt > 0)
return;
stmt_list_free(&rule->stmts);
@@ -606,13 +612,20 @@ struct symbol *symbol_get(const struct scope *scope, const char *identifier)
if (!sym)
return NULL;
+ if (sym->refcnt == 0)
+ BUG("sym->refcnt is 0, use-after-free for identifier %s\n", identifier);
+ if (sym->refcnt == UINT_MAX)
+ BUG("sym->refcnt overflow, identifier %s\n", identifier);
+
sym->refcnt++;
return sym;
}
static void symbol_put(struct symbol *sym)
{
+ assert(sym->refcnt > 0);
+
if (--sym->refcnt == 0) {
free_const(sym->identifier);
expr_free(sym->expr);
@@ -732,6 +745,9 @@ struct chain *chain_alloc(void)
struct chain *chain_get(struct chain *chain)
{
+ assert(chain->refcnt > 0);
+ assert(chain->refcnt < UINT_MAX);
+
chain->refcnt++;
return chain;
}
@@ -741,6 +757,7 @@ void chain_free(struct chain *chain)
struct rule *rule, *next;
int i;
+ assert(chain->refcnt > 0);
if (--chain->refcnt > 0)
return;
list_for_each_entry_safe(rule, next, &chain->rules, list)
@@ -1176,6 +1193,7 @@ void table_free(struct table *table)
struct set *set, *nset;
struct obj *obj, *nobj;
+ assert(table->refcnt > 0);
if (--table->refcnt > 0)
return;
if (table->comment)
@@ -1214,6 +1232,8 @@ void table_free(struct table *table)
struct table *table_get(struct table *table)
{
+ assert(table->refcnt > 0);
+ assert(table->refcnt < UINT_MAX);
table->refcnt++;
return table;
}
@@ -1687,12 +1707,15 @@ struct obj *obj_alloc(const struct location *loc)
struct obj *obj_get(struct obj *obj)
{
+ assert(obj->refcnt > 0);
+ assert(obj->refcnt < UINT_MAX);
obj->refcnt++;
return obj;
}
void obj_free(struct obj *obj)
{
+ assert(obj->refcnt > 0);
if (--obj->refcnt > 0)
return;
free_const(obj->comment);
@@ -2270,6 +2293,8 @@ struct flowtable *flowtable_alloc(const struct location *loc)
struct flowtable *flowtable_get(struct flowtable *flowtable)
{
+ assert(flowtable->refcnt > 0);
+ assert(flowtable->refcnt < UINT_MAX);
flowtable->refcnt++;
return flowtable;
}
@@ -2278,6 +2303,7 @@ void flowtable_free(struct flowtable *flowtable)
{
int i;
+ assert(flowtable->refcnt > 0);
if (--flowtable->refcnt > 0)
return;
handle_free(&flowtable->handle);
--
2.51.0
^ permalink raw reply related [flat|nested] 3+ messages in thread
* Re: [PATCH nft] src: add refcount asserts
2025-10-17 8:13 [PATCH nft] src: add refcount asserts Florian Westphal
@ 2025-10-19 20:40 ` Pablo Neira Ayuso
2025-10-19 20:57 ` Florian Westphal
0 siblings, 1 reply; 3+ messages in thread
From: Pablo Neira Ayuso @ 2025-10-19 20:40 UTC (permalink / raw)
To: Florian Westphal; +Cc: netfilter-devel
Hi Florian,
On Fri, Oct 17, 2025 at 10:13:53AM +0200, Florian Westphal wrote:
> _get() functions must not be used when refcnt is 0, as expr_free()
> releases expressions on 1 -> 0 transition.
>
> Also, check that a refcount would not overflow from UINT_MAX to 0.
> This helps catching use-after-free refcounting bugs even when nft
> is built without ASAN support.
This is indeed needed, thanks. Comments below.
> Signed-off-by: Florian Westphal <fw@strlen.de>
> ---
> include/rule.h | 2 +-
> src/expression.c | 12 ++++++++++++
> src/rule.c | 28 ++++++++++++++++++++++++++++
> 3 files changed, 41 insertions(+), 1 deletion(-)
>
> diff --git a/include/rule.h b/include/rule.h
> index 8d2f29d09337..bcdc50cad59d 100644
> --- a/include/rule.h
> +++ b/include/rule.h
> @@ -115,7 +115,7 @@ struct symbol {
> struct list_head list;
> const char *identifier;
> struct expr *expr;
> - int refcnt;
> + unsigned int refcnt;
> };
>
> extern void symbol_bind(struct scope *scope, const char *identifier,
> diff --git a/src/expression.c b/src/expression.c
> index 019c263f187b..3e74a669c8a4 100644
> --- a/src/expression.c
> +++ b/src/expression.c
> @@ -68,6 +68,11 @@ struct expr *expr_clone(const struct expr *expr)
>
> struct expr *expr_get(struct expr *expr)
> {
> + if (expr->refcnt == 0)
> + BUG("refcnt 0, use-after-free on type %s\n", expr_name(expr));
> + if (expr->refcnt == UINT_MAX)
> + BUG("refcnt overflow for type %s\n", expr_name(expr));
> +
> expr->refcnt++;
> return expr;
> }
> @@ -84,6 +89,10 @@ void expr_free(struct expr *expr)
> {
> if (expr == NULL)
> return;
> +
> + if (expr->refcnt == 0)
> + BUG("refcnt 0, possible double-free on type %p %s\n",expr, expr_name(expr));
> +
> if (--expr->refcnt > 0)
> return;
>
> @@ -343,11 +352,14 @@ static void variable_expr_clone(struct expr *new, const struct expr *expr)
> new->scope = expr->scope;
> new->sym = expr->sym;
>
> + assert(expr->sym->refcnt > 0);
> + assert(expr->sym->refcnt < UINT_MAX);
Would it be possible to consolidate all this with a macro, eg.
assert_refcount_safe(expr->sym->refcnt);
> expr->sym->refcnt++;
> }
>
> static void variable_expr_destroy(struct expr *expr)
> {
> + assert(expr->sym->refcnt > 0);
> expr->sym->refcnt--;
> }
>
> diff --git a/src/rule.c b/src/rule.c
> index d0a62a3ee002..722e48ae254b 100644
> --- a/src/rule.c
> +++ b/src/rule.c
> @@ -181,6 +181,8 @@ struct set *set_clone(const struct set *set)
>
> struct set *set_get(struct set *set)
> {
> + assert(set->refcnt > 0);
> + assert(set->refcnt < UINT_MAX);
> set->refcnt++;
> return set;
> }
> @@ -189,6 +191,7 @@ void set_free(struct set *set)
> {
> struct stmt *stmt, *next;
>
> + assert(set->refcnt > 0);
> if (--set->refcnt > 0)
> return;
>
> @@ -484,12 +487,15 @@ struct rule *rule_alloc(const struct location *loc, const struct handle *h)
>
> struct rule *rule_get(struct rule *rule)
> {
> + assert(rule->refcnt > 0);
> + assert(rule->refcnt < UINT_MAX);
> rule->refcnt++;
> return rule;
> }
>
> void rule_free(struct rule *rule)
> {
> + assert(rule->refcnt > 0);
> if (--rule->refcnt > 0)
> return;
> stmt_list_free(&rule->stmts);
> @@ -606,13 +612,22 @@ struct symbol *symbol_get(const struct scope *scope, const char *identifier)
> if (!sym)
> return NULL;
>
> + if (sym->refcnt == 0)
> + BUG("sym->recnt is 0, use-after-free for identifier %s\n", identifier);
^^^^^
typo
Maybe simply use assert() everywhere, or add a BUG_ON(cond, text) helper
where these refcount checks are done.
> +
> + assert(sym->refcnt > 0);
> sym->refcnt++;
>
> + if (sym->refcnt == UINT_MAX)
> + BUG("sym->refcnt overflow, identifier %s\n", identifier);
> +
> return sym;
> }
>
> static void symbol_put(struct symbol *sym)
> {
> + assert(sym->refcnt > 0);
> +
> if (--sym->refcnt == 0) {
> free_const(sym->identifier);
> expr_free(sym->expr);
> @@ -732,6 +747,9 @@ struct chain *chain_alloc(void)
>
> struct chain *chain_get(struct chain *chain)
> {
> + assert(chain->refcnt > 0);
> + assert(chain->refcnt < UINT_MAX);
> +
> chain->refcnt++;
> return chain;
> }
> @@ -741,6 +759,7 @@ void chain_free(struct chain *chain)
> struct rule *rule, *next;
> int i;
>
> + assert(chain->refcnt > 0);
> if (--chain->refcnt > 0)
> return;
> list_for_each_entry_safe(rule, next, &chain->rules, list)
> @@ -1176,6 +1195,7 @@ void table_free(struct table *table)
> struct set *set, *nset;
> struct obj *obj, *nobj;
>
> + assert(table->refcnt > 0);
> if (--table->refcnt > 0)
> return;
> if (table->comment)
> @@ -1214,6 +1234,8 @@ void table_free(struct table *table)
>
> struct table *table_get(struct table *table)
> {
> + assert(table->refcnt > 0);
> + assert(table->refcnt < UINT_MAX);
> table->refcnt++;
> return table;
> }
> @@ -1687,12 +1709,15 @@ struct obj *obj_alloc(const struct location *loc)
>
> struct obj *obj_get(struct obj *obj)
> {
> + assert(obj->refcnt > 0);
> + assert(obj->refcnt < UINT_MAX);
> obj->refcnt++;
> return obj;
> }
>
> void obj_free(struct obj *obj)
> {
> + assert(obj->refcnt > 0);
> if (--obj->refcnt > 0)
> return;
> free_const(obj->comment);
> @@ -2270,6 +2295,8 @@ struct flowtable *flowtable_alloc(const struct location *loc)
>
> struct flowtable *flowtable_get(struct flowtable *flowtable)
> {
> + assert(flowtable->refcnt > 0);
> + assert(flowtable->refcnt < UINT_MAX);
> flowtable->refcnt++;
> return flowtable;
> }
> @@ -2278,6 +2305,7 @@ void flowtable_free(struct flowtable *flowtable)
> {
> int i;
>
> + assert(flowtable->refcnt > 0);
> if (--flowtable->refcnt > 0)
> return;
> handle_free(&flowtable->handle);
> --
> 2.51.0
>
>
^ permalink raw reply [flat|nested] 3+ messages in thread
* Re: [PATCH nft] src: add refcount asserts
2025-10-19 20:40 ` Pablo Neira Ayuso
@ 2025-10-19 20:57 ` Florian Westphal
0 siblings, 0 replies; 3+ messages in thread
From: Florian Westphal @ 2025-10-19 20:57 UTC (permalink / raw)
To: Pablo Neira Ayuso; +Cc: netfilter-devel
Pablo Neira Ayuso <pablo@netfilter.org> wrote:
> > @@ -343,11 +352,14 @@ static void variable_expr_clone(struct expr *new, const struct expr *expr)
> > new->scope = expr->scope;
> > new->sym = expr->sym;
> >
> > + assert(expr->sym->refcnt > 0);
> > + assert(expr->sym->refcnt < UINT_MAX);
>
> Would it be possible to consolidate all this with a macro, eg.
>
> assert_refcount_safe(expr->sym->refcnt);
Sure, makes sense to me, I'll spin a v2.
^ permalink raw reply [flat|nested] 3+ messages in thread
end of thread, other threads:[~2025-10-19 20:57 UTC | newest]
Thread overview: 3+ messages (download: mbox.gz follow: Atom feed
-- links below jump to the message on this page --
2025-10-17 8:13 [PATCH nft] src: add refcount asserts Florian Westphal
2025-10-19 20:40 ` Pablo Neira Ayuso
2025-10-19 20:57 ` Florian Westphal
This is a public inbox, see mirroring instructions
for how to clone and mirror all data and code used for this inbox;
as well as URLs for NNTP newsgroup(s).