aboutsummaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
authorMike Frysinger <vapier@google.com>2018-08-14 17:38:57 -0400
committerMike Frysinger <vapier@google.com>2018-08-14 18:00:12 -0400
commit3edca1e85b30cc7de24ac2fb3b3389a6741c9304 (patch)
tree713e1ce83d7d96b771803e215707174bae5516d6 /tools
parent89cbc32f005628664e7e5dcffd23efc974fa0a52 (diff)
downloadplatform_external_minijail-3edca1e85b30cc7de24ac2fb3b3389a6741c9304.tar.gz
platform_external_minijail-3edca1e85b30cc7de24ac2fb3b3389a6741c9304.tar.bz2
platform_external_minijail-3edca1e85b30cc7de24ac2fb3b3389a6741c9304.zip
generate_seccomp_policy.py: rewrite in python3 (and style)
Python2 is on the way out, so rewrite for python3. Also clean up the style a bit while we're in here. Bug: None Test: used `strace ls /proc` log and diffed the filter before & after Change-Id: Ib96a24c5fd70d6a938bcf15f96a303bd8a34fe9e
Diffstat (limited to 'tools')
-rwxr-xr-xtools/generate_seccomp_policy.py112
1 files changed, 62 insertions, 50 deletions
diff --git a/tools/generate_seccomp_policy.py b/tools/generate_seccomp_policy.py
index 7705a060..3f030845 100755
--- a/tools/generate_seccomp_policy.py
+++ b/tools/generate_seccomp_policy.py
@@ -1,4 +1,5 @@
-#!/usr/bin/python
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
#
# Copyright (C) 2016 The Android Open Source Project
#
@@ -17,10 +18,16 @@
# This script will take any number of trace files generated by strace(1)
# and output a system call filtering policy suitable for use with Minijail.
-from collections import namedtuple, defaultdict
+"""Helper tool to generate a minijail seccomp filter from strace output."""
+
+from __future__ import print_function
+
+import argparse
+import collections
import re
import sys
+
NOTICE = """# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -36,16 +43,17 @@ NOTICE = """# Copyright (C) 2018 The Android Open Source Project
# limitations under the License.
"""
-ALLOW = "%s: 1"
+ALLOW = '%s: 1'
# This ignores any leading PID tag and trailing <unfinished ...>, and extracts
# the syscall name and the argument list.
LINE_RE = re.compile(r'^\s*(?:\[[^]]*\]|\d+)?\s*([a-zA-Z0-9_]+)\(([^)<]*)')
-SOCKETCALLS = ["accept", "bind", "connect", "getpeername", "getsockname",
- "getsockopt", "listen", "recv", "recvfrom", "recvmsg", "send",
- "sendmsg", "sendto", "setsockopt", "shutdown", "socket",
- "socketpair"]
+SOCKETCALLS = {
+ 'accept', 'bind', 'connect', 'getpeername', 'getsockname', 'getsockopt',
+ 'listen', 'recv', 'recvfrom', 'recvmsg', 'send', 'sendmsg', 'sendto',
+ 'setsockopt', 'shutdown', 'socket', 'socketpair',
+}
# /* Protocol families. */
# #define PF_UNSPEC 0 /* Unspecified. */
@@ -68,37 +76,46 @@ SOCKETCALLS = ["accept", "bind", "connect", "getpeername", "getsockname",
# #define PF_KEY 15 /* PF_KEY key management API. */
# #define PF_NETLINK 16
-ArgInspectionEntry = namedtuple("ArgInspectionEntry", "arg_index value_set")
+ArgInspectionEntry = collections.namedtuple('ArgInspectionEntry',
+ ('arg_index', 'value_set'))
-def usage(argv):
- print "%s <trace file> [trace files...]" % argv[0]
+def get_parser():
+ """Return a CLI parser for this tool."""
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument('traces', nargs='+', help='The strace logs.')
+ return parser
-def main(traces):
- syscalls = defaultdict(int)
+def main(argv):
+ parser = get_parser()
+ opts = parser.parse_args(argv)
+
+ syscalls = collections.defaultdict(int)
uses_socketcall = False
- basic_set = ["restart_syscall", "exit", "exit_group",
- "rt_sigreturn"]
- frequent_set = []
+ basic_set = [
+ 'restart_syscall', 'exit', 'exit_group', 'rt_sigreturn',
+ ]
syscall_sets = {}
- syscall_set_list = [["sigreturn", "rt_sigreturn"],
- ["sigaction", "rt_sigaction"],
- ["sigprocmask", "rt_sigprocmask"],
- ["open", "openat"],
- ["mmap", "mremap"],
- ["mmap2", "mremap"]]
+ syscall_set_list = [
+ ['sigreturn', 'rt_sigreturn'],
+ ['sigaction', 'rt_sigaction'],
+ ['sigprocmask', 'rt_sigprocmask'],
+ ['open', 'openat'],
+ ['mmap', 'mremap'],
+ ['mmap2', 'mremap'],
+ ]
arg_inspection = {
- "socket": ArgInspectionEntry(0, set([])), # int domain
- "ioctl": ArgInspectionEntry(1, set([])), # int request
- "prctl": ArgInspectionEntry(0, set([])), # int option
- "mmap": ArgInspectionEntry(2, set([])), # int prot
- "mmap2": ArgInspectionEntry(2, set([])), # int prot
- "mprotect": ArgInspectionEntry(2, set([])), # int prot
+ 'socket': ArgInspectionEntry(0, set([])), # int domain
+ 'ioctl': ArgInspectionEntry(1, set([])), # int request
+ 'prctl': ArgInspectionEntry(0, set([])), # int option
+ 'mmap': ArgInspectionEntry(2, set([])), # int prot
+ 'mmap2': ArgInspectionEntry(2, set([])), # int prot
+ 'mprotect': ArgInspectionEntry(2, set([])), # int prot
}
for syscall_list in syscall_set_list:
@@ -107,9 +124,9 @@ def main(traces):
other_syscalls.remove(syscall)
syscall_sets[syscall] = other_syscalls
- for trace_filename in traces:
- if "i386" in trace_filename or ("x86" in trace_filename and
- "64" not in trace_filename):
+ for trace_filename in opts.traces:
+ if 'i386' in trace_filename or ('x86' in trace_filename and
+ '64' not in trace_filename):
uses_socketcall = True
trace_file = open(trace_filename)
@@ -120,7 +137,7 @@ def main(traces):
syscall, args = matches.groups()
if uses_socketcall and syscall in SOCKETCALLS:
- syscall = "socketcall"
+ syscall = 'socketcall'
syscalls[syscall] += 1
@@ -130,17 +147,16 @@ def main(traces):
arg_value = args[arg_inspection[syscall].arg_index]
arg_inspection[syscall].value_set.add(arg_value)
- sorted_syscalls = list(zip(*sorted(syscalls.iteritems(),
- key=lambda pair: pair[1],
- reverse=True))[0])
-
- print NOTICE
+ # Sort the syscalls based on frequency. This way the calls that are used
+ # more often come first which in turn speeds up the filter slightly.
+ sorted_syscalls = list(
+ x[0] for x in sorted(syscalls.items(), key=lambda pair: pair[1],
+ reverse=True)
+ )
- # Add frequent syscalls first.
- for frequent_syscall in frequent_set:
- sorted_syscalls.remove(frequent_syscall)
+ print(NOTICE)
- all_syscalls = frequent_set + sorted_syscalls
+ all_syscalls = sorted_syscalls
# Add the basic set once the frequency drops below 2.
below_ten_index = -1
@@ -160,16 +176,12 @@ def main(traces):
if syscall in arg_inspection:
arg_index = arg_inspection[syscall].arg_index
arg_values = arg_inspection[syscall].value_set
- arg_filter = " || ".join(["arg%d == %s" % (arg_index, arg_value)
- for arg_value in arg_values])
- print syscall + ": " + arg_filter
+ arg_filter = ' || '.join('arg%d == %s' % (arg_index, arg_value)
+ for arg_value in arg_values)
+ print(syscall + ': ' + arg_filter)
else:
- print ALLOW % syscall
-
+ print(ALLOW % syscall)
-if __name__ == "__main__":
- if len(sys.argv) < 2:
- usage(sys.argv)
- sys.exit(1)
- main(sys.argv[1:])
+if __name__ == '__main__':
+ sys.exit(main(sys.argv[1:]))