root/scapy/automaton.py

Revision 862:52d055e522bd, 12.2 kB (checked in by Phil <phil@secdev.org>, 5 months ago)

Merged with one-file scapy

Line 
1 ## This file is part of Scapy
2 ## See http://www.secdev.org/projects/scapy for more information
3 ## Copyright (C) Philippe Biondi <phil@secdev.org>
4 ## This program is published under a GPLv2 license
5
6 import types,itertools,time
7 from select import select
8 from config import conf
9 from utils import do_graph
10 from error import log_interactive
11 from plist import PacketList
12 from data import MTU
13
14 ##############
15 ## Automata ##
16 ##############
17
class ATMT:
    """Decorators and markers used to declare Automaton states and transitions.

    The decorators only attach ``atmt_*`` attributes to the decorated
    functions; Automaton_metaclass later collects them into per-class tables.
    """
    # Values stored in a decorated function's ``atmt_type`` attribute.
    STATE = "State"
    ACTION = "Action"
    CONDITION = "Condition"
    RECV = "Receive condition"
    TIMEOUT = "Timeout condition"

    class NewStateRequested(Exception):
        """Request to move an automaton to another state.

        Instances are returned by state wrappers (see ATMT.state) and raised
        by conditions to trigger a transition. They remember the target state
        function, the automaton, and the arguments for the state function.
        They also carry separate arguments for the triggering condition's
        actions.
        """
        def __init__(self, state_func, automaton, *args, **kargs):
            self.func = state_func
            # Mirror the state's markers: self.state, .initial, .error, .final
            for marker in ("state", "initial", "error", "final"):
                setattr(self, marker, getattr(state_func, "atmt_" + marker))
            Exception.__init__(self, "Request state [%s]" % self.state)
            self.automaton = automaton
            # NOTE: this deliberately replaces Exception.args (set just above)
            # with the positional arguments destined to the state function.
            self.args = args
            self.kargs = kargs
            self.action_parameters()  # actions get no arguments by default

        def action_parameters(self, *args, **kargs):
            """Set the arguments given to the condition's actions; return self."""
            self.action_args, self.action_kargs = args, kargs
            return self

        def run(self):
            """Call the target state function on the automaton; return its output."""
            state_func = self.func
            return state_func(self.automaton, *self.args, **self.kargs)

    @staticmethod
    def state(initial=0,final=0,error=0):
        """Declare a state method.

        The method is replaced by a wrapper that, when called, returns an
        ATMT.NewStateRequested for it instead of running it. Both the wrapper
        and the original function carry the state markers. initial/final/error
        flag the state's role, and the wrapper keeps the original function in
        ``atmt_origfunc``.
        """
        def deco(f,initial=initial, final=final):
            markers = {
                "atmt_type": ATMT.STATE,
                "atmt_state": f.__name__,
                "atmt_initial": initial,
                "atmt_final": final,
                "atmt_error": error,
            }
            f.__dict__.update(markers)

            def state_wrapper(self, *args, **kargs):
                return ATMT.NewStateRequested(f, self, *args, **kargs)

            state_wrapper.__name__ = "%s_wrapper" % f.__name__
            state_wrapper.__dict__.update(markers, atmt_origfunc=f)
            return state_wrapper
        return deco

    @staticmethod
    def action(cond, prio=0):
        """Register the decorated function as an action of condition ``cond``.

        A function may be decorated several times. ``atmt_cond`` maps each
        condition name to the action's priority for that condition.
        """
        def deco(f,cond=cond):
            if not hasattr(f, "atmt_type"):
                # First ATMT decoration: start an empty condition -> prio map.
                f.atmt_cond = {}
            f.atmt_type = ATMT.ACTION
            f.atmt_cond[cond.atmt_condname] = prio
            return f
        return deco

    @staticmethod
    def condition(state, prio=0):
        """Declare an immediate condition, evaluated right after ``state`` runs.

        A state's conditions are tried in increasing ``prio`` order.
        """
        def deco(f, state=state):
            f.__dict__.update(atmt_type=ATMT.CONDITION,
                              atmt_state=state.atmt_state,
                              atmt_condname=f.__name__,
                              atmt_prio=prio)
            return f
        return deco

    @staticmethod
    def receive_condition(state, prio=0):
        """Declare a condition evaluated on each packet received while in ``state``.

        Receive conditions are tried in increasing ``prio`` order.
        """
        def deco(f, state=state):
            f.__dict__.update(atmt_type=ATMT.RECV,
                              atmt_state=state.atmt_state,
                              atmt_condname=f.__name__,
                              atmt_prio=prio)
            return f
        return deco

    @staticmethod
    def timeout(state, timeout):
        """Declare a condition fired ``timeout`` seconds after entering ``state``."""
        def deco(f, state=state, timeout=timeout):
            f.__dict__.update(atmt_type=ATMT.TIMEOUT,
                              atmt_state=state.atmt_state,
                              atmt_timeout=timeout,
                              atmt_condname=f.__name__)
            return f
        return deco
100
101
class Automaton_metaclass(type):
    """Metaclass collecting the ATMT-decorated members of an automaton class.

    At class-creation time it builds the tables the automaton engine reads.
    ``states`` maps a state name to its wrapper. ``conditions``,
    ``recv_conditions`` and ``timeout`` hold per-state condition lists.
    ``actions`` holds per-condition action lists, and ``initial_states``
    lists the initial states.
    """
    def __new__(mcs, name, bases, dct):
        cls = super(Automaton_metaclass, mcs).__new__(mcs, name, bases, dct)
        cls.states={}
        cls.state = None
        cls.recv_conditions={}
        cls.conditions={}
        cls.timeout={}
        cls.actions={}
        cls.initial_states=[]

        # Gather members in MRO order, keeping the first definition found,
        # exactly as normal attribute lookup does. A subclass override thus
        # hides the base-class member, also with multiple inheritance, where
        # a plain breadth-first walk over __bases__ can pick the wrong one.
        members = {}
        for c in cls.__mro__:
            for k,v in c.__dict__.iteritems():
                if k not in members:
                    members[k] = v

        decorated = [v for v in members.itervalues()
                     if type(v) is types.FunctionType and hasattr(v, "atmt_type")]

        # First pass: create the per-state and per-condition containers.
        for m in decorated:
            if m.atmt_type == ATMT.STATE:
                s = m.atmt_state
                cls.states[s] = m
                cls.recv_conditions[s]=[]
                cls.conditions[s]=[]
                cls.timeout[s]=[]
                if m.atmt_initial:
                    cls.initial_states.append(m)
            elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT]:
                cls.actions[m.atmt_condname] = []

        # Second pass: attach conditions to states and actions to conditions.
        for m in decorated:
            if m.atmt_type == ATMT.CONDITION:
                cls.conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.RECV:
                cls.recv_conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.TIMEOUT:
                cls.timeout[m.atmt_state].append((m.atmt_timeout, m))
            elif m.atmt_type == ATMT.ACTION:
                for c in m.atmt_cond:
                    cls.actions[c].append(m)

        # Timeouts are tried in increasing delay order. The (None, None)
        # entry is an end-of-list sentinel that run() relies on.
        for v in cls.timeout.itervalues():
            v.sort(key=lambda entry: entry[0])
            v.append((None, None))
        for v in itertools.chain(cls.conditions.itervalues(),
                                 cls.recv_conditions.itervalues()):
            v.sort(key=lambda cond: cond.atmt_prio)
        for condname,actlst in cls.actions.iteritems():
            actlst.sort(key=lambda act, cn=condname: act.atmt_cond[cn])

        return cls

    def graph(self, **kargs):
        """Describe the automaton in Graphviz syntax and hand it to do_graph.

        This is a metaclass method, so it is called on the automaton class
        itself (``MyAutomaton.graph()``) and ``self`` is that class. Edges are
        inferred by looking for state names among the names and constants
        used by each function's code. Green edges are state-to-state, purple
        are immediate conditions, red are receive conditions and blue are
        timeouts. Edge labels list the condition followed by its actions.
        ``kargs`` are forwarded to do_graph, whose result is returned.
        """
        def targets(func):
            # Names/constants referenced by func's code that are state names.
            code = func.func_code
            return [n for n in code.co_names+code.co_consts if n in self.states]

        def label(cond, text):
            # Append one "\l>[action]" line per action of cond.
            for act in self.actions[cond.atmt_condname]:
                text += "\\l>[%s]" % act.__name__
            return text

        # Name the graph after the automaton class, not the metaclass.
        out = ['digraph "%s" {\n' % self.__name__]

        # Keep initial nodes at the beginning for better rendering.
        initial, others = [], []
        for st in self.states.itervalues():
            if st.atmt_initial:
                initial.append('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state)
            elif st.atmt_final:
                others.append('\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state)
            elif st.atmt_error:
                others.append('\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state)
        initial.reverse()
        out.extend(initial)
        out.extend(others)

        for st in self.states.itervalues():
            for n in targets(st.atmt_origfunc):
                out.append('\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state,n))

        for color,table in (("purple", self.conditions),
                            ("red", self.recv_conditions)):
            for k,v in table.iteritems():
                for f in v:
                    for n in targets(f):
                        l = label(f, f.atmt_condname)
                        out.append('\t"%s" -> "%s" [label="%s", color=%s];\n' % (k,n,l,color))

        for k,v in self.timeout.iteritems():
            for t,f in v:
                if f is None:
                    continue  # end-of-list sentinel
                for n in targets(f):
                    l = label(f, "%s/%.1fs" % (f.atmt_condname,t))
                    out.append('\t"%s" -> "%s" [label="%s",color=blue];\n' % (k,n,l))
        out.append("}\n")
        return do_graph("".join(out), **kargs)
200        
201
202
203 class Automaton:
204     __metaclass__ = Automaton_metaclass
205
206     def __init__(self, *args, **kargs):
207         self.debug_level=0
208         self.init_args=args
209         self.init_kargs=kargs
210         self.parse_args(*args, **kargs)
211
212     def debug(self, lvl, msg):
213         if self.debug_level >= lvl:
214             log_interactive.debug(msg)
215            
216
217
218
219     class ErrorState(Exception):
220         def __init__(self, msg, result=None):
221             Exception.__init__(self, msg)
222             self.result = result
223     class Stuck(ErrorState):
224         pass
225
226     def parse_args(self, debug=0, store=1, **kargs):
227         self.debug_level=debug
228         self.socket_kargs = kargs
229         self.store_packets = store
230        
231
232     def master_filter(self, pkt):
233         return True
234
235     def run_condition(self, cond, *args, **kargs):
236         try:
237             cond(self,*args, **kargs)
238         except ATMT.NewStateRequested, state_req:
239             self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, state_req.state))
240             if cond.atmt_type == ATMT.RECV:
241                 self.packets.append(args[0])
242             for action in self.actions[cond.atmt_condname]:
243                 self.debug(2, "   + Running action [%s]" % action.func_name)
244                 action(self, *state_req.action_args, **state_req.action_kargs)
245             raise
246         else:
247             self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname))
248            
249
250     def run(self, *args, **kargs):
251         # Update default parameters
252         a = args+self.init_args[len(args):]
253         k = self.init_kargs
254         k.update(kargs)
255         self.parse_args(*a,**k)
256
257         # Start the automaton
258         self.state=self.initial_states[0](self)
259         self.send_sock = conf.L3socket()
260         l = conf.L2listen(**self.socket_kargs)
261         self.packets = PacketList(name="session[%s]"%self.__class__.__name__)
262         while 1:
263             try:
264                 self.debug(1, "## state=[%s]" % self.state.state)
265
266                 # Entering a new state. First, call new state function
267                 state_output = self.state.run()
268                 if self.state.error:
269                     raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), result=state_output)
270                 if self.state.final:
271                     return state_output
272
273                 if state_output is None:
274                     state_output = ()
275                 elif type(state_output) is not list:
276                     state_output = state_output,
277                
278                 # Then check immediate conditions
279                 for cond in self.conditions[self.state.state]:
280                     self.run_condition(cond, *state_output)
281
282                 # If still there and no conditions left, we are stuck!
283                 if ( len(self.recv_conditions[self.state.state]) == 0
284                      and len(self.timeout[self.state.state]) == 1 ):
285                     raise self.Stuck("stuck in [%s]" % self.state.state,result=state_output)
286
287                 # Finally listen and pay attention to timeouts
288                 expirations = iter(self.timeout[self.state.state])
289                 next_timeout,timeout_func = expirations.next()
290                 t0 = time.time()
291                
292                 while 1:
293                     t = time.time()-t0
294                     if next_timeout is not None:
295                         if next_timeout <= t:
296                             self.run_condition(timeout_func, *state_output)
297                             next_timeout,timeout_func = expirations.next()
298                     if next_timeout is None:
299                         remain = None
300                     else:
301                         remain = next_timeout-t
302    
303                     r,_,_ = select([l],[],[],remain)
304                     if l in r:
305                         pkt = l.recv(MTU)
306                         if pkt is not None:
307                             if self.master_filter(pkt):
308                                 self.debug(3, "RECVD: %s" % pkt.summary())
309                                 for rcvcond in self.recv_conditions[self.state.state]:
310                                     self.run_condition(rcvcond, pkt, *state_output)
311                             else:
312                                 self.debug(4, "FILTR: %s" % pkt.summary())
313
314             except ATMT.NewStateRequested,state_req:
315                 self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state))
316                 self.state = state_req
317             except KeyboardInterrupt:
318                 self.debug(1,"Interrupted by user")
319                 break
320
321     def my_send(self, pkt):
322         self.send_sock.send(pkt)
323
324     def send(self, pkt):
325         self.my_send(pkt)
326         self.debug(3,"SENT : %s" % pkt.summary())
327         self.packets.append(pkt.copy())
328
329
330        
Note: See TracBrowser for help on using the browser.