summaryrefslogtreecommitdiff
path: root/arch/x86/net/bpf_jit_comp.c
diff options
context:
space:
mode:
Diffstat (limited to 'arch/x86/net/bpf_jit_comp.c')
-rw-r--r--arch/x86/net/bpf_jit_comp.c115
1 files changed, 80 insertions, 35 deletions
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 0554e8aef4d5..45e4eb5bcbb2 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -13,10 +13,9 @@
#include <linux/if_vlan.h>
#include <asm/cacheflush.h>
#include <asm/set_memory.h>
+#include <asm/nospec-branch.h>
#include <linux/bpf.h>
-int bpf_jit_enable __read_mostly;
-
/*
* assembly code in arch/x86/net/bpf_jit.S
*/
@@ -154,6 +153,11 @@ static bool is_ereg(u32 reg)
BIT(BPF_REG_AX));
}
+static bool is_axreg(u32 reg)
+{
+ return reg == BPF_REG_0;
+}
+
/* add modifiers if 'reg' maps to x64 registers r8..r15 */
static u8 add_1mod(u8 byte, u32 reg)
{
@@ -287,7 +291,7 @@ static void emit_bpf_tail_call(u8 **pprog)
EMIT2(0x89, 0xD2); /* mov edx, edx */
EMIT3(0x39, 0x56, /* cmp dword ptr [rsi + 16], edx */
offsetof(struct bpf_array, map.max_entries));
-#define OFFSET1 43 /* number of bytes to jump */
+#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* number of bytes to jump */
EMIT2(X86_JBE, OFFSET1); /* jbe out */
label1 = cnt;
@@ -296,7 +300,7 @@ static void emit_bpf_tail_call(u8 **pprog)
*/
EMIT2_off32(0x8B, 0x85, 36); /* mov eax, dword ptr [rbp + 36] */
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
-#define OFFSET2 32
+#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
EMIT2(X86_JA, OFFSET2); /* ja out */
label2 = cnt;
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
@@ -310,7 +314,7 @@ static void emit_bpf_tail_call(u8 **pprog)
* goto out;
*/
EMIT3(0x48, 0x85, 0xC0); /* test rax,rax */
-#define OFFSET3 10
+#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
EMIT2(X86_JE, OFFSET3); /* je out */
label3 = cnt;
@@ -323,7 +327,7 @@ static void emit_bpf_tail_call(u8 **pprog)
* rdi == ctx (1st arg)
* rax == prog->bpf_func + prologue_size
*/
- EMIT2(0xFF, 0xE0); /* jmp rax */
+ RETPOLINE_RAX_BPF_JIT();
/* out: */
BUILD_BUG_ON(cnt - label1 != OFFSET1);
@@ -447,16 +451,36 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
else if (is_ereg(dst_reg))
EMIT1(add_1mod(0x40, dst_reg));
+ /* b3 holds 'normal' opcode, b2 short form only valid
+ * in case dst is eax/rax.
+ */
switch (BPF_OP(insn->code)) {
- case BPF_ADD: b3 = 0xC0; break;
- case BPF_SUB: b3 = 0xE8; break;
- case BPF_AND: b3 = 0xE0; break;
- case BPF_OR: b3 = 0xC8; break;
- case BPF_XOR: b3 = 0xF0; break;
+ case BPF_ADD:
+ b3 = 0xC0;
+ b2 = 0x05;
+ break;
+ case BPF_SUB:
+ b3 = 0xE8;
+ b2 = 0x2D;
+ break;
+ case BPF_AND:
+ b3 = 0xE0;
+ b2 = 0x25;
+ break;
+ case BPF_OR:
+ b3 = 0xC8;
+ b2 = 0x0D;
+ break;
+ case BPF_XOR:
+ b3 = 0xF0;
+ b2 = 0x35;
+ break;
}
if (is_imm8(imm32))
EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
+ else if (is_axreg(dst_reg))
+ EMIT1_off32(b2, imm32);
else
EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
break;
@@ -545,26 +569,6 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
*/
EMIT2(0x31, 0xd2);
- if (BPF_SRC(insn->code) == BPF_X) {
- /* if (src_reg == 0) return 0 */
-
- /* cmp r11, 0 */
- EMIT4(0x49, 0x83, 0xFB, 0x00);
-
- /* jne .+9 (skip over pop, pop, xor and jmp) */
- EMIT2(X86_JNE, 1 + 1 + 2 + 5);
- EMIT1(0x5A); /* pop rdx */
- EMIT1(0x58); /* pop rax */
- EMIT2(0x31, 0xc0); /* xor eax, eax */
-
- /* jmp cleanup_addr
- * addrs[i] - 11, because there are 11 bytes
- * after this insn: div, mov, pop, pop, mov
- */
- jmp_offset = ctx->cleanup_addr - (addrs[i] - 11);
- EMIT1_off32(0xE9, jmp_offset);
- }
-
if (BPF_CLASS(insn->code) == BPF_ALU64)
/* div r11 */
EMIT3(0x49, 0xF7, 0xF3);
@@ -1109,19 +1113,29 @@ common_load:
return proglen;
}
+struct x64_jit_data {
+ struct bpf_binary_header *header;
+ int *addrs;
+ u8 *image;
+ int proglen;
+ struct jit_context ctx;
+};
+
struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
struct bpf_binary_header *header = NULL;
struct bpf_prog *tmp, *orig_prog = prog;
+ struct x64_jit_data *jit_data;
int proglen, oldproglen = 0;
struct jit_context ctx = {};
bool tmp_blinded = false;
+ bool extra_pass = false;
u8 *image = NULL;
int *addrs;
int pass;
int i;
- if (!bpf_jit_enable)
+ if (!prog->jit_requested)
return orig_prog;
tmp = bpf_jit_blind_constants(prog);
@@ -1135,10 +1149,28 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
prog = tmp;
}
+ jit_data = prog->aux->jit_data;
+ if (!jit_data) {
+ jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
+ if (!jit_data) {
+ prog = orig_prog;
+ goto out;
+ }
+ prog->aux->jit_data = jit_data;
+ }
+ addrs = jit_data->addrs;
+ if (addrs) {
+ ctx = jit_data->ctx;
+ oldproglen = jit_data->proglen;
+ image = jit_data->image;
+ header = jit_data->header;
+ extra_pass = true;
+ goto skip_init_addrs;
+ }
addrs = kmalloc(prog->len * sizeof(*addrs), GFP_KERNEL);
if (!addrs) {
prog = orig_prog;
- goto out;
+ goto out_addrs;
}
/* Before first pass, make a rough estimation of addrs[]
@@ -1149,6 +1181,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
addrs[i] = proglen;
}
ctx.cleanup_addr = proglen;
+skip_init_addrs:
/* JITed image shrinks with every pass and the loop iterates
* until the image stops shrinking. Very large bpf programs
@@ -1189,7 +1222,15 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
if (image) {
bpf_flush_icache(header, image + proglen);
- bpf_jit_binary_lock_ro(header);
+ if (!prog->is_func || extra_pass) {
+ bpf_jit_binary_lock_ro(header);
+ } else {
+ jit_data->addrs = addrs;
+ jit_data->ctx = ctx;
+ jit_data->proglen = proglen;
+ jit_data->image = image;
+ jit_data->header = header;
+ }
prog->bpf_func = (void *)image;
prog->jited = 1;
prog->jited_len = proglen;
@@ -1197,8 +1238,12 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
prog = orig_prog;
}
+ if (!prog->is_func || extra_pass) {
out_addrs:
- kfree(addrs);
+ kfree(addrs);
+ kfree(jit_data);
+ prog->aux->jit_data = NULL;
+ }
out:
if (tmp_blinded)
bpf_jit_prog_release_other(prog, prog == orig_prog ?