diff --git a/include/linux/bpf.h b/include/linux/bpf.h index 656879385fbf..5fec2d1be6d7 100644 --- a/include/linux/bpf.h +++ b/include/linux/bpf.h @@ -2795,10 +2795,13 @@ struct btf_id_set; bool btf_id_set_contains(const struct btf_id_set *set, u32 id); #define MAX_BPRINTF_VARARGS 12 +#define MAX_BPRINTF_BUF 1024 struct bpf_bprintf_data { u32 *bin_args; + char *buf; bool get_bin_args; + bool get_buf; }; int bpf_bprintf_prepare(char *fmt, u32 fmt_size, const u64 *raw_args, diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c index 9cca02e13f2e..23aa8cf8fd1a 100644 --- a/kernel/bpf/helpers.c +++ b/kernel/bpf/helpers.c @@ -756,19 +756,20 @@ static int bpf_trace_copy_string(char *buf, void *unsafe_ptr, char fmt_ptype, /* Per-cpu temp buffers used by printf-like helpers to store the bprintf binary * arguments representation. */ -#define MAX_BPRINTF_BUF_LEN 512 +#define MAX_BPRINTF_BIN_ARGS 512 /* Support executing three nested bprintf helper calls on a given CPU */ #define MAX_BPRINTF_NEST_LEVEL 3 struct bpf_bprintf_buffers { - char tmp_bufs[MAX_BPRINTF_NEST_LEVEL][MAX_BPRINTF_BUF_LEN]; + char bin_args[MAX_BPRINTF_BIN_ARGS]; + char buf[MAX_BPRINTF_BUF]; }; -static DEFINE_PER_CPU(struct bpf_bprintf_buffers, bpf_bprintf_bufs); + +static DEFINE_PER_CPU(struct bpf_bprintf_buffers[MAX_BPRINTF_NEST_LEVEL], bpf_bprintf_bufs); static DEFINE_PER_CPU(int, bpf_bprintf_nest_level); -static int try_get_fmt_tmp_buf(char **tmp_buf) +static int try_get_buffers(struct bpf_bprintf_buffers **bufs) { - struct bpf_bprintf_buffers *bufs; int nest_level; preempt_disable(); @@ -778,15 +779,14 @@ static int try_get_fmt_tmp_buf(char **tmp_buf) preempt_enable(); return -EBUSY; } - bufs = this_cpu_ptr(&bpf_bprintf_bufs); - *tmp_buf = bufs->tmp_bufs[nest_level - 1]; + *bufs = this_cpu_ptr(&bpf_bprintf_bufs[nest_level - 1]); return 0; } void bpf_bprintf_cleanup(struct bpf_bprintf_data *data) { - if (!data->bin_args) + if (!data->bin_args && !data->buf) return; if (WARN_ON_ONCE(this_cpu_read(bpf_bprintf_nest_level) == 0)) return; @@ -811,7 +811,9 @@ void bpf_bprintf_cleanup(struct bpf_bprintf_data *data) int bpf_bprintf_prepare(char *fmt, u32 fmt_size, const u64 *raw_args, u32 num_args, struct bpf_bprintf_data *data) { + bool get_buffers = (data->get_bin_args && num_args) || data->get_buf; char *unsafe_ptr = NULL, *tmp_buf = NULL, *tmp_buf_end, *fmt_end; + struct bpf_bprintf_buffers *buffers = NULL; size_t sizeof_cur_arg, sizeof_cur_ip; int err, i, num_spec = 0; u64 cur_arg; @@ -822,14 +824,19 @@ int bpf_bprintf_prepare(char *fmt, u32 fmt_size, const u64 *raw_args, return -EINVAL; fmt_size = fmt_end - fmt; - if (data->get_bin_args) { - if (num_args && try_get_fmt_tmp_buf(&tmp_buf)) - return -EBUSY; + if (get_buffers && try_get_buffers(&buffers)) + return -EBUSY; - tmp_buf_end = tmp_buf + MAX_BPRINTF_BUF_LEN; + if (data->get_bin_args) { + if (num_args) + tmp_buf = buffers->bin_args; + tmp_buf_end = tmp_buf + MAX_BPRINTF_BIN_ARGS; data->bin_args = (u32 *)tmp_buf; } + if (data->get_buf) + data->buf = buffers->buf; + for (i = 0; i < fmt_size; i++) { if ((!isprint(fmt[i]) && !isspace(fmt[i])) || !isascii(fmt[i])) { err = -EINVAL; diff --git a/kernel/trace/bpf_trace.c b/kernel/trace/bpf_trace.c index 2129f7c68bb5..23ce498bca97 100644 --- a/kernel/trace/bpf_trace.c +++ b/kernel/trace/bpf_trace.c @@ -369,8 +369,6 @@ static const struct bpf_func_proto *bpf_get_probe_write_proto(void) return &bpf_probe_write_user_proto; } -static DEFINE_RAW_SPINLOCK(trace_printk_lock); - #define MAX_TRACE_PRINTK_VARARGS 3 #define BPF_TRACE_PRINTK_SIZE 1024 @@ -380,9 +378,8 @@ BPF_CALL_5(bpf_trace_printk, char *, fmt, u32, fmt_size, u64, arg1, u64 args[MAX_TRACE_PRINTK_VARARGS] = { arg1, arg2, arg3 }; struct bpf_bprintf_data data = { .get_bin_args = true, + .get_buf = true, }; - static char buf[BPF_TRACE_PRINTK_SIZE]; - unsigned long flags; int ret; ret = bpf_bprintf_prepare(fmt, fmt_size, args, @@ -390,11 +387,9 @@ BPF_CALL_5(bpf_trace_printk, char *, fmt, u32, fmt_size, u64, arg1, if (ret < 0) return ret; - raw_spin_lock_irqsave(&trace_printk_lock, flags); - ret = bstr_printf(buf, sizeof(buf), fmt, data.bin_args); + ret = bstr_printf(data.buf, MAX_BPRINTF_BUF, fmt, data.bin_args); - trace_bpf_trace_printk(buf); - raw_spin_unlock_irqrestore(&trace_printk_lock, flags); + trace_bpf_trace_printk(data.buf); bpf_bprintf_cleanup(&data); @@ -434,9 +429,8 @@ BPF_CALL_4(bpf_trace_vprintk, char *, fmt, u32, fmt_size, const void *, args, { struct bpf_bprintf_data data = { .get_bin_args = true, + .get_buf = true, }; - static char buf[BPF_TRACE_PRINTK_SIZE]; - unsigned long flags; int ret, num_args; if (data_len & 7 || data_len > MAX_BPRINTF_VARARGS * 8 || @@ -448,11 +442,9 @@ BPF_CALL_4(bpf_trace_vprintk, char *, fmt, u32, fmt_size, const void *, args, if (ret < 0) return ret; - raw_spin_lock_irqsave(&trace_printk_lock, flags); - ret = bstr_printf(buf, sizeof(buf), fmt, data.bin_args); + ret = bstr_printf(data.buf, MAX_BPRINTF_BUF, fmt, data.bin_args); - trace_bpf_trace_printk(buf); - raw_spin_unlock_irqrestore(&trace_printk_lock, flags); + trace_bpf_trace_printk(data.buf); bpf_bprintf_cleanup(&data);