1 #!/usr/bin/python -tO
  2
  3 # transform an iptables firewall to a dot file
  4 #
  5 # marcus fritzsch - http://fritschy.de - 20070620
  6 #
# this program is licensed under the GNU GPL version 2 or, at your
# option, any later version.
  9 #
 10 # some code was shamelessly stolen from pyroman:
 11 # http://pyroman.alioth.debian.org/
 12
 13 import re, sys, os.path
 14
# targets that are not user-defined chains; drawn as ellipses by
# dot_output().  Based on the iptables 1.3.6 documentation, extended with
# the core QUEUE target and targets documented in iptables-extensions(8).
BUILTIN_TARGETS = set ([ 'ACCEPT', 'DROP', 'RETURN', 'QUEUE',
  'AUDIT', 'CHECKSUM', 'CLASSIFY', 'CLUSTERIP', 'CONNMARK', 'CONNSECMARK',
  'CT', 'DNAT', 'DSCP', 'ECN', 'HMARK', 'IDLETIMER', 'IPV4OPTSSTRIP', 'LED',
  'LOG', 'MARK', 'MASQUERADE', 'MIRROR', 'NETMAP', 'NFLOG', 'NFQUEUE',
  'NOTRACK', 'RATEEST', 'REDIRECT', 'REJECT', 'ROUTE', 'SAME', 'SECMARK',
  'SET', 'SNAT', 'SYNPROXY', 'TARPIT', 'TCPMSS', 'TCPOPTSTRIP', 'TEE', 'TOS',
  'TPROXY', 'TRACE', 'TTL', 'ULOG' ])

# set to True by get_iptables() when the rules carry no [packets:bytes]
# counters; dot_output() then draws every edge solid instead of dashing
# the zero-traffic ones
COUNTER_ERROR = False
 25
def print_msg (out, pfx, *a):
  """Write the line `<program name>: <pfx>: <args>' to the stream `out'.

  The extra arguments are converted with str() and joined by single
  spaces; the program name is the basename of sys.argv[0].
  """
  prog = os.path.basename (sys.argv [0])
  words = " ".join ([str (item) for item in a])
  out.write ("%s: %s: %s" % (prog, pfx, words) + "\n")
 29
def print_err (*a):
  """Report an error on standard error, prefixed with `ERROR'."""
  return print_msg (sys.stderr, 'ERROR', *a)
 32
def print_warn (*a):
  """Report a warning on standard error, prefixed with `WARNING'."""
  stream = sys.stderr
  print_msg (stream, 'WARNING', *a)
 35
class Counter (object):
  """A packets/bytes pair, printed and parsed as `[packets:bytes]'."""

  def __init__ (self, p = 0, b = 0):
    self.packets = p
    self.bytes = b

  def __str__ (self):
    return "[%d:%d]" % (self.packets, self.bytes)

  def __iadd__ (self, c):
    # accumulate in place: callers rely on the same object being updated
    self.packets, self.bytes = self.packets + c.packets, self.bytes + c.bytes
    return self

  @staticmethod
  def fromstr (s):
    """Parse `[packets:bytes]'; any other text yields a zero Counter."""
    m = re.match (r"^\[(\d+):(\d+)\]$", s)
    if m is None:
      return Counter (0, 0) # be nice
    packets, nbytes = [int (g) for g in m.groups ()]
    return Counter (packets, nbytes)
 51
class Rule (object):
  """One `-A' line of an iptables-save dump.

  Attributes: rule (match part of the line), target (jump target), args
  (target options), num (running rule number) and counter (a Counter).
  rule and args are escaped for use inside a quoted dot record label.
  """

  # characters that must be backslash-escaped inside a quoted dot record
  # label: the backslash itself, the record syntax characters { } | < >
  # and quotes
  _LABEL_ESCAPES = dict ([(ch, '\\' + ch) for ch in '\\{}|<>\'"'])

  def __init__ (self, r, t, a, n, c = None):
    self.target, self.num = t, n
    self.counter = c or Counter ()
    self.rule = Rule._escape (r or "")
    self.args = Rule._escape (a or "")
    # keep the record field non-empty so the label stays readable
    if not (self.rule or self.args):
      self.rule = "[empty]"

  @staticmethod
  def _escape (text):
    """Backslash-escape every character dot treats specially in labels."""
    lookup = Rule._LABEL_ESCAPES.get
    return ''.join ([lookup (ch, ch) for ch in text])
 60
class Chain (object):
  """A chain of one table: name, policy, rules and traffic counter."""

  def __init__ (self, n, p = ""):
    self.name = n
    self.policy = p       # None or "" when the chain has no policy
    self.rules = []       # Rule objects, in dump order
    self.counter = None   # filled in later by the parser / dot_output()
 65
class Table (object):
  """One iptables table (e.g. `filter') parsed from a dump."""

  def __init__ (self, n):
    self.name = n
    self.chains = {}    # chain name -> Chain
    self.targets = {}   # target name -> number of rules jumping to it
 69
def rename (n):
  """Turn an iptables chain/table name into a valid bare dot identifier.

  Every character other than ASCII letters, digits and '_' (hyphens, dots,
  colons, ...) becomes '_'; a colon would otherwise be read as a record
  port separator.  A leading digit gets a '_' prefix and a name equal to
  a dot keyword (compared case-insensitively) gets a '_' suffix.  Plain
  names such as INPUT or my_chain are returned unchanged.
  """
  name = re.sub (r"[^A-Za-z0-9_]", "_", n)
  if name [:1].isdigit ():
    name = "_" + name   # dot IDs may not start with a digit
  if name.lower () in ('node', 'edge', 'graph', 'digraph', 'subgraph',
      'strict'):
    name += "_"         # dot keywords are case-insensitive
  return name
 73
def _read_dump_lines (dump_file):
  """Return the whitespace-stripped lines of an iptables-save dump.

  Reads `dump_file' if it is given, otherwise runs `iptables-save -c'.
  Raises IOError when the file cannot be read, or when iptables-save
  cannot be started or exits with a non-zero status.
  """
  if dump_file:
    infile = open (dump_file)
    try:
      return [x.strip () for x in infile.readlines ()]
    finally:
      infile.close ()

  from subprocess import Popen, PIPE
  try:
    proc = Popen (["iptables-save", "-c"], stdout = PIPE, stderr = PIPE,
        universal_newlines = True)
  except OSError:
    # e.g. iptables-save is not installed; callers only handle IOError
    raise IOError ("cannot run iptables-save: %s" % sys.exc_info ()[1])
  output, errors = proc.communicate ()
  if proc.returncode != 0:
    raise IOError ("iptables-save failed (exit status %d): %s" % \
        (proc.returncode, errors.strip ()))
  return [x.strip () for x in output.splitlines ()]

def get_iptables (dump_file = None):
  """Parse an iptables-save dump into a dict mapping table name -> Table.

  The dump is read from `dump_file' if given, otherwise `iptables-save -c'
  is run.  Table, chain and target names are passed through rename().
  Malformed lines are reported with print_err() and skipped.  If rules
  carry no [packets:bytes] counters, a warning is printed once and the
  global COUNTER_ERROR is set.  Raises IOError if the dump is unavailable.
  """
  global COUNTER_ERROR

  lines = _read_dump_lines (dump_file)

  tables = dict ()
  cur_tbl = None

  # [packets:bytes] -A CHAIN [rule-spec] -j TARGET [target-args]
  match_line = re.compile (\
      r"^(?:\[(\d+):(\d+)\] )?(?:-A ([^ ]+))(?: (.*))?(?: -j ([^ ]+)(?: (.*))?)$")
  PACKETS, BYTES, CHAIN, RULE, TARGET, T_ARGS = range(1,7)

  ruleCounter, line_ctr = 1, 0
  for line in lines:

    line_ctr += 1

    if not line or line [0] == '#': # ignore blank lines and comments
      continue

    if line [0] == '*': # new table
      cur_tbl = Table (rename (line[1:]))
      continue

    if line == 'COMMIT': # end current table
      if cur_tbl is None:
        print_err ("line %d: COMMIT without a table" % line_ctr)
      else:
        tables [cur_tbl.name] = cur_tbl
        cur_tbl = None
      continue

    # everything else belongs to a table.  Explicit checks instead of
    # assert: the shebang runs python with -O, which strips asserts.
    if cur_tbl is None:
      print_err ("line %d is outside of a table: `%s'" % (line_ctr, line))
      continue

    if line [0] == ':': # new chain for current table
      s = line[1:].split ()
      if len (s) not in (2, 3):
        print_err ("line %d: malformed chain declaration: `%s'" % \
            (line_ctr, line))
        continue
      policy, counter = s[1], None
      if policy == '-': policy = None   # user-defined chain
      elif len (s) == 3: counter = Counter.fromstr (s[2])
      chain = rename (s[0])
      cur_tbl.chains [chain] = Chain (chain, policy)
      cur_tbl.chains [chain].counter = counter
      continue

    m = match_line.match (line)
    if not m:
      print_err ("line %d could not be parsed: `%s'" % (line_ctr, line))
      continue

    packets, nbytes = m.group (PACKETS), m.group (BYTES)
    if not COUNTER_ERROR and (not packets or not nbytes):
      print_warn ("counters not available!")
      COUNTER_ERROR = True
    counter = Counter (int (packets or '0'), int (nbytes or '0'))

    chain, target = rename (m.group (CHAIN)), rename (m.group (TARGET))
    if chain not in cur_tbl.chains:
      print_err ("line %d: rule for undeclared chain `%s'" % \
          (line_ctr, chain))
      continue
    cur_tbl.chains [chain].rules.append (Rule (m.group (RULE), target, \
        m.group (T_ARGS), ruleCounter, counter))
    ruleCounter += 1
    cur_tbl.targets [target] = cur_tbl.targets.get (target, 0) + 1
  # end for line in lines

  if cur_tbl is not None:
    # truncated dump: keep what was parsed instead of silently dropping it
    print_warn ("table `%s' has no COMMIT line" % cur_tbl.name)
    tables [cur_tbl.name] = cur_tbl

  return tables
147
def dot_output (table, out, min_line_width, max_line_width):
  """Write `table' to the file object `out' as a Graphviz dot digraph.

  Each chain with rules becomes a record node with one field per rule and
  a final field for its policy (RETURN when the chain has none).  Every
  field gets an edge to the rule's target, or to the policy.  Targets that
  are in BUILTIN_TARGETS, or are not chains of this table, are drawn as
  ellipses.  Edge widths grow with log(bytes + 1) from min_line_width up
  to exactly max_line_width for the busiest edge.  Zero-traffic edges are
  dashed unless the dump had no counters (COUNTER_ERROR).  A table
  without rules only produces a warning.
  """
  if not table.targets:
    print_warn ("table `%s' is empty, doing nothing" % table.name)
    return

  def print_dot (*a):
    out.write (" ".join (a))

  def get_policy (chain):
    # chains without a policy are drawn as ending in RETURN
    return chain.policy or 'RETURN'

  def record_prologue (chain):
    # open the record label; its first field is the chain name
    print_dot ('  %s [label="%s ' % (chain.name, chain.name))

  def record_epilogue (chain):
    # append the policy field, register its edge and close the label
    print_dot ('| ')
    pol = get_policy (chain)
    print_dot ("<policy_%s> %s " % (pol, pol))
    edges ['%s:policy_%s' % (chain.name, pol)] = (pol, chain.counter)
    print_dot ('"];\n')

  class MangleName (object):
    """Hand out stable record port names: name_1, name_2, ..."""
    def __init__ (self):
      self.__ctr, self.__lut = 1, dict ()
    def get (self, chain, rule):
      name = chain+rule.rule+rule.args+str(rule.num)
      if name not in self.__lut:
        self.__lut [name] = self.__ctr
        self.__ctr += 1
      return 'name_' + str (self.__lut [name])
  nameMangler = MangleName ()

  print_dot ("digraph %s {\n" % table.name)
  print_dot ("  rankdir=LR;\n")
  print_dot ("  edge [splines=true];\n")
  print_dot ("  node [shape=record];\n\n")

  edges = {} # record port -> (target, Counter), drawn at the end
  targets = {} # needed to identify not yet drawn targets
  displayed = set ()

  # precompute counters, i.e. add all rules counters to their target
  counters, max_bytes = {}, 0
  for chain in table.chains.values ():
    for rule in chain.rules:
      total = counters.get (rule.target)
      if total is None:
        total = counters [rule.target] = Counter ()
      total += rule.counter

  for chain in table.chains.values ():
    targets [get_policy (chain)] = targets.get (get_policy (chain), 0) + 1

    # fix user defined chains with the precomputed counter
    if not chain.counter:
      if counters.get (chain.name):
        chain.counter = counters [chain.name]
      else:
        chain.counter = Counter () # last escape, i.e. unreferenced chains
    max_bytes = max (chain.counter.bytes, max_bytes)

    if not chain.rules:
      continue

    record_prologue (chain)
    displayed.add (chain.name)
    for rule in chain.rules:
      max_bytes = max (rule.counter.bytes, max_bytes)
      port = nameMangler.get (chain.name, rule)
      print_dot ('| ')
      print_dot ('<%s> %s ' % (port, \
          rule.rule+(rule.args and " "+rule.args or "")))
      edges ['%s:%s' % (chain.name, port)] = (rule.target, rule.counter)
      targets [rule.target] = targets.get (rule.target, 0) + 1
    record_epilogue (chain)

  # explicitly print nodes not yet printed: builtin targets and targets
  # that are not chains of this table (e.g. extension targets missing from
  # BUILTIN_TARGETS) as ellipses, referenced but empty chains as records
  for t in targets:
    if t in BUILTIN_TARGETS or t not in table.chains:
      print_dot ('  %s [label=%s shape=ellipse];\n' % (t, t))
    elif t not in displayed:
      record_prologue (table.chains [t])
      record_epilogue (table.chains [t])

  # one more newline...
  print_dot ('\n')

  import math
  # map log(bytes + 1) linearly onto [min_line_width, max_line_width].
  # Without any traffic (max_bytes == 0) every edge gets the minimum width
  # instead of failing with a math domain or division-by-zero error.
  if max_bytes > 0:
    scale = float (max_line_width - min_line_width) / math.log (max_bytes + 1)
  else:
    scale = 0.0
  for start, (end, ctr) in edges.items ():
    if ctr.bytes == 0 and not COUNTER_ERROR:
      style = "dashed, setlinewidth(%f)" % min_line_width
    else:
      style = "solid, setlinewidth(%f)" % (min_line_width + \
          math.log (float (ctr.bytes+1)) * scale)
    print_dot ('  %s -> %s [style="%s"];\n' % (start, end, style))

  print_dot ('}\n')
247
def usage ():
  """Print the command line help to stdout and exit with status 1."""
  prog = sys.argv [0]
  sys.stdout.write ("""Usage: %s [-W num] [-w num] [-s str] [-c] [-i file] [-f] [-h] [table ...]
  -W num    set max line width to num (default: 5)
  -w num    set min (and zero-traffic) line width to num (default: 1)
  -s str    append str to every output file name (default: none)
  -c        causes the dot file to be printed to stdout (default: off)
  -i file   read an iptables-save dump from file
  -f        force overwriting of existing files
  -h        show this help message

example:
  %s -s _test
  will save the filter table to a filter_test.dot file

By default only the filter table will be dotted.
The table selection will be set to exactly the ones
on the command line if given.
""" % (prog, prog))
  sys.exit (1)
266
if __name__ == '__main__':
  from getopt import gnu_getopt as getopt, GetoptError
  import os.path as path

  line_width0, line_width1 = 1, 5
  to_stdout = force = False
  root_suffix = ""
  in_file = None

  try:
    opts, args = getopt (sys.argv [1:], 'w:W:s:i:fch')
  except GetoptError:
    print_err (sys.exc_info ()[1])
    usage ()

  for opt, arg in opts:
    if opt in ('-w', '-W'):
      try:
        width = float (arg)
      except ValueError:
        print_err ("option %s expects a number, got `%s'" % (opt, arg))
        sys.exit (1)
      if opt == '-w': line_width0 = width
      else: line_width1 = width
    elif opt == '-c': to_stdout = True
    elif opt == '-s': root_suffix = arg
    elif opt == '-i': in_file = arg
    elif opt == '-f': force = True
    elif opt == '-h': usage ()

  def check_sanity (check, message, exit=True):
    """Report `message' if `check' holds; abort unless exit is False."""
    if check:
      print_err (message)
      if exit: sys.exit (1)

  check_sanity (line_width0 <= 0, "min line width is invalid")
  check_sanity (line_width1 <= 0, "max line width is invalid")
  check_sanity (line_width0 >= line_width1, \
      "min line width >= max line width")
  check_sanity (to_stdout and len (args) > 1, \
      "writing more than 1 table to stdout!", False)

  # check file name and look for the force
  def get_file_name (f):
    """Return the absolute path for `f', or None if it exists and -f is unset."""
    f = path.abspath (f)
    if path.exists (f):
      if not force:
        print_err ("output file %s exists use -f to force execution"%f)
        return
    return f

  try:
    iptables = get_iptables (in_file)
  except IOError:
    print_err (sys.exc_info ()[1])
    sys.exit (1)

  status = 0
  for tbl in args or ['filter']:
    if tbl not in iptables:
      print_err ("no such table `%s'" % tbl)
      status = 1
      continue
    try:
      if to_stdout:
        dot_output (iptables [tbl], sys.stdout, line_width0, line_width1)
        continue
      outfile = get_file_name ("%s%s.dot" % (tbl, root_suffix))
      if not outfile:
        status = 1
        continue
      out = open (outfile, 'w')
      try:
        dot_output (iptables [tbl], out, line_width0, line_width1)
      finally:
        out.close ()
    except (IOError, OSError):
      # Python 2 reports open/write failures as IOError, not OSError
      print_err (sys.exc_info ()[1])
      status = 1
    # want to see other errors here

  sys.exit (status)