diff options
| author | Mike Frysinger <vapier@google.com> | 2018-08-14 17:38:57 -0400 |
|---|---|---|
| committer | Mike Frysinger <vapier@google.com> | 2018-08-14 18:00:12 -0400 |
| commit | 3edca1e85b30cc7de24ac2fb3b3389a6741c9304 (patch) | |
| tree | 713e1ce83d7d96b771803e215707174bae5516d6 /tools | |
| parent | 89cbc32f005628664e7e5dcffd23efc974fa0a52 (diff) | |
| download | platform_external_minijail-3edca1e85b30cc7de24ac2fb3b3389a6741c9304.tar.gz platform_external_minijail-3edca1e85b30cc7de24ac2fb3b3389a6741c9304.tar.bz2 platform_external_minijail-3edca1e85b30cc7de24ac2fb3b3389a6741c9304.zip | |
generate_seccomp_policy.py: rewrite in python3 (and style)
Python2 is on the way out, so rewrite for python3.
Also clean up the style a bit while we're in here.
Bug: None
Test: used `strace ls /proc` log and diffed the filter before & after
Change-Id: Ib96a24c5fd70d6a938bcf15f96a303bd8a34fe9e
Diffstat (limited to 'tools')
| -rwxr-xr-x | tools/generate_seccomp_policy.py | 112 |
1 files changed, 62 insertions, 50 deletions
diff --git a/tools/generate_seccomp_policy.py b/tools/generate_seccomp_policy.py index 7705a060..3f030845 100755 --- a/tools/generate_seccomp_policy.py +++ b/tools/generate_seccomp_policy.py @@ -1,4 +1,5 @@ -#!/usr/bin/python +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- # # Copyright (C) 2016 The Android Open Source Project # @@ -17,10 +18,16 @@ # This script will take any number of trace files generated by strace(1) # and output a system call filtering policy suitable for use with Minijail. -from collections import namedtuple, defaultdict +"""Helper tool to generate a minijail seccomp filter from strace output.""" + +from __future__ import print_function + +import argparse +import collections import re import sys + NOTICE = """# Copyright (C) 2018 The Android Open Source Project # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -36,16 +43,17 @@ NOTICE = """# Copyright (C) 2018 The Android Open Source Project # limitations under the License. """ -ALLOW = "%s: 1" +ALLOW = '%s: 1' # This ignores any leading PID tag and trailing <unfinished ...>, and extracts # the syscall name and the argument list. LINE_RE = re.compile(r'^\s*(?:\[[^]]*\]|\d+)?\s*([a-zA-Z0-9_]+)\(([^)<]*)') -SOCKETCALLS = ["accept", "bind", "connect", "getpeername", "getsockname", - "getsockopt", "listen", "recv", "recvfrom", "recvmsg", "send", - "sendmsg", "sendto", "setsockopt", "shutdown", "socket", - "socketpair"] +SOCKETCALLS = { + 'accept', 'bind', 'connect', 'getpeername', 'getsockname', 'getsockopt', + 'listen', 'recv', 'recvfrom', 'recvmsg', 'send', 'sendmsg', 'sendto', + 'setsockopt', 'shutdown', 'socket', 'socketpair', +} # /* Protocol families. */ # #define PF_UNSPEC 0 /* Unspecified. */ @@ -68,37 +76,46 @@ SOCKETCALLS = ["accept", "bind", "connect", "getpeername", "getsockname", # #define PF_KEY 15 /* PF_KEY key management API. */ # #define PF_NETLINK 16 -ArgInspectionEntry = namedtuple("ArgInspectionEntry", "arg_index value_set") +ArgInspectionEntry = collections.namedtuple('ArgInspectionEntry', + ('arg_index', 'value_set')) -def usage(argv): - print "%s <trace file> [trace files...]" % argv[0] +def get_parser(): + """Return a CLI parser for this tool.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('traces', nargs='+', help='The strace logs.') + return parser -def main(traces): - syscalls = defaultdict(int) +def main(argv): + parser = get_parser() + opts = parser.parse_args(argv) + + syscalls = collections.defaultdict(int) uses_socketcall = False - basic_set = ["restart_syscall", "exit", "exit_group", - "rt_sigreturn"] - frequent_set = [] + basic_set = [ + 'restart_syscall', 'exit', 'exit_group', 'rt_sigreturn', + ] syscall_sets = {} - syscall_set_list = [["sigreturn", "rt_sigreturn"], - ["sigaction", "rt_sigaction"], - ["sigprocmask", "rt_sigprocmask"], - ["open", "openat"], - ["mmap", "mremap"], - ["mmap2", "mremap"]] + syscall_set_list = [ + ['sigreturn', 'rt_sigreturn'], + ['sigaction', 'rt_sigaction'], + ['sigprocmask', 'rt_sigprocmask'], + ['open', 'openat'], + ['mmap', 'mremap'], + ['mmap2', 'mremap'], + ] arg_inspection = { - "socket": ArgInspectionEntry(0, set([])), # int domain - "ioctl": ArgInspectionEntry(1, set([])), # int request - "prctl": ArgInspectionEntry(0, set([])), # int option - "mmap": ArgInspectionEntry(2, set([])), # int prot - "mmap2": ArgInspectionEntry(2, set([])), # int prot - "mprotect": ArgInspectionEntry(2, set([])), # int prot + 'socket': ArgInspectionEntry(0, set([])), # int domain + 'ioctl': ArgInspectionEntry(1, set([])), # int request + 'prctl': ArgInspectionEntry(0, set([])), # int option + 'mmap': ArgInspectionEntry(2, set([])), # int prot + 'mmap2': ArgInspectionEntry(2, set([])), # int prot + 'mprotect': ArgInspectionEntry(2, set([])), # int prot } for syscall_list in syscall_set_list: @@ -107,9 +124,9 @@ def main(traces): other_syscalls.remove(syscall) syscall_sets[syscall] = other_syscalls - for trace_filename in traces: - if "i386" in trace_filename or ("x86" in trace_filename and - "64" not in trace_filename): + for trace_filename in opts.traces: + if 'i386' in trace_filename or ('x86' in trace_filename and + '64' not in trace_filename): uses_socketcall = True trace_file = open(trace_filename) @@ -120,7 +137,7 @@ def main(traces): syscall, args = matches.groups() if uses_socketcall and syscall in SOCKETCALLS: - syscall = "socketcall" + syscall = 'socketcall' syscalls[syscall] += 1 @@ -130,17 +147,16 @@ def main(traces): arg_value = args[arg_inspection[syscall].arg_index] arg_inspection[syscall].value_set.add(arg_value) - sorted_syscalls = list(zip(*sorted(syscalls.iteritems(), - key=lambda pair: pair[1], - reverse=True))[0]) - - print NOTICE + # Sort the syscalls based on frequency. This way the calls that are used + # more often come first which in turn speeds up the filter slightly. + sorted_syscalls = list( + x[0] for x in sorted(syscalls.items(), key=lambda pair: pair[1], + reverse=True) + ) - # Add frequent syscalls first. - for frequent_syscall in frequent_set: - sorted_syscalls.remove(frequent_syscall) + print(NOTICE) - all_syscalls = frequent_set + sorted_syscalls + all_syscalls = sorted_syscalls # Add the basic set once the frequency drops below 2. below_ten_index = -1 @@ -160,16 +176,12 @@ def main(traces): if syscall in arg_inspection: arg_index = arg_inspection[syscall].arg_index arg_values = arg_inspection[syscall].value_set - arg_filter = " || ".join(["arg%d == %s" % (arg_index, arg_value) - for arg_value in arg_values]) - print syscall + ": " + arg_filter + arg_filter = ' || '.join('arg%d == %s' % (arg_index, arg_value) + for arg_value in arg_values) + print(syscall + ': ' + arg_filter) else: - print ALLOW % syscall - + print(ALLOW % syscall) -if __name__ == "__main__": - if len(sys.argv) < 2: - usage(sys.argv) - sys.exit(1) - main(sys.argv[1:]) +if __name__ == '__main__': + sys.exit(main(sys.argv[1:])) |
