from ryu.base import app_manager
from ryu.lib import ofctl_v1_3
from ryu.lib.packet import ether_types
from ryu.lib.packet import ethernet
from ryu.ofproto import ofproto_v1_3
from ryu.controller.handler import set_ev_cls
from ryu.controller.handler import CONFIG_DISPATCHER, MAIN_DISPATCHER
from ryu.controller import ofp_event
from ryu.lib.packet import packet
from tkinter import ttk
import tkinter
import json
import logging
import os

LOG = logging.getLogger(__name__)


class Firewall:
    """Decides, for the first packet of each new flow, forward or drop.

    Rules live in a JSON file (``wallrules.json`` by default) that is re-read
    whenever its modification time changes, so edits made through the GUI or
    by hand take effect on the next new flow.  After normalization the rules
    dict holds:

    * ``default_action`` -- bool; allow (True) or drop (False) when no other
      rule decides.
    * ``static_vlans`` -- int vLAN id -> list of MAC/IP strings.  Hosts in a
      negative id are banned (the GUI blacklist is id -1); hosts not listed
      anywhere are in vLAN 0.  Only hosts in the same vLAN may talk.
    * ``special_protocols`` -- ordered ``{"protocol", "action"}`` rules.
    * ``special_ports`` -- ordered TCP/UDP port-range rules.
    """

    def __init__(self, datapath, rule_path="wallrules.json", *args, **kwargs):
        super(Firewall, self).__init__(*args, **kwargs)
        self.datapath = datapath
        self._rule_path = rule_path
        self.rules = dict()
        self._rules_age = 0  # mtime of the rules file at the last load/save
        self.are_rules_changed = False
        try:
            self.load_rules()
        except Exception:
            # Missing or unreadable rules file: (re)create it with defaults.
            # Exception rather than BaseException so Ctrl-C / SystemExit are
            # not swallowed.
            # NOTE(review): this also overwrites a file that exists but fails
            # to parse, discarding its contents -- confirm that is intended.
            self.save_rules()

    def save_rules(self):
        """Normalize ``self.rules`` and write them to the rules file."""
        # Normalize and serialize before opening: mode 'w' truncates the
        # file, so a failure here must not leave an empty rules file behind.
        self._fix_rules()
        text = json.dumps(self.rules, indent=4, sort_keys=True)
        with open(self._rule_path, 'w') as f:
            f.write(text)
        self.are_rules_changed = True
        self._update_rules_age()

    def _update_rules_age(self):
        """Remember the rules file's current modification time."""
        self._rules_age = os.path.getmtime(self._rule_path)

    def load_rules(self):
        """Reload the rules file if it changed on disk since the last load/save.

        ``self.rules`` is replaced only after the new content has been parsed
        and normalized, so a broken file never leaves half-loaded rules.

        Raises:
            OSError: the rules file is missing or unreadable.
            ValueError: the file is not valid JSON or has a non-integer vLAN id.
            Other exceptions may escape for structurally unexpected content.
        """
        if self._check_for_disk_updates():
            with open(self._rule_path) as f:
                rules = json.load(f)
            self.rules = self._normalize_rules(rules)
            self.are_rules_changed = True
            self._update_rules_age()

    @staticmethod
    def _normalize_rules(rules):
        """Fill in missing rule keys and clean up ``static_vlans`` in place.

        JSON object keys are always strings, so vLAN ids are converted back to
        ints; vLANs without members are dropped.  Returns *rules*.
        """
        rules['default_action'] = rules.get('default_action', True)
        rules['static_vlans'] = {
            int(vlanid): members
            for vlanid, members in rules.get('static_vlans', dict()).items()
            if len(members) > 0
        }
        rules['special_protocols'] = rules.get('special_protocols', list())
        rules['special_ports'] = rules.get('special_ports', list())
        return rules

    def _fix_rules(self):
        """Normalize ``self.rules`` in place (see ``_normalize_rules``)."""
        self._normalize_rules(self.rules)

    def _check_for_disk_updates(self):
        """Return True when the rules file is newer than the last load/save.

        Raises OSError if the file does not exist.
        """
        return self._rules_age < os.path.getmtime(self._rule_path)

    def get_vlanid(self, mac, ip):
        """Return the static vLAN id of the host with *mac* (or *ip*).

        The first ``static_vlans`` entry listing either address wins; *ip* may
        be None for non-IP packets.  Returns 0 when the host is not listed.
        """
        for svlanid, svlanmembers in self.rules.get('static_vlans', dict()).items():
            if mac in svlanmembers or (ip is not None and ip in svlanmembers):
                return int(svlanid)
        return 0

    def actions_for_new_flow(self, event, message, datapath, ofproto, eth,
                             parser, allow_action, disallow_action=None):
        """Decide what to do with the first packet of a new flow.

        Checks are applied in order: banned hosts, vLAN mismatch, protocol
        rules, TCP/UDP port-range rules, then the default action.

        Args:
            event, datapath, ofproto, parser: not used here; part of the
                calling convention.
            message: the packet-in message; ``message.data`` is parsed.
            eth: the parsed Ethernet header of the packet.
            allow_action: actions to return when the packet is allowed.
            disallow_action: actions to return when the packet is dropped;
                defaults to a new empty list (no actions).

        Returns:
            ``(actions,)`` or ``(actions, save_in_flow_table)``.  Port-range
            verdicts return ``False`` as the second element so that they are
            not cached in the switch's flow table.
        """
        if disallow_action is None:
            disallow_action = []
        try:
            self.load_rules()
        except Exception as err:
            # Keep enforcing the last rules that loaded successfully while the
            # file is missing or mid-edit, instead of failing every packet-in.
            LOG.warning("firewall: cannot reload %s (%s); keeping previous rules",
                        self._rule_path, err)

        if self.rules.get('default_action', True):
            default_action = allow_action
        else:
            default_action = disallow_action

        protocols = list(packet.Packet(message.data).protocols)
        ipvx = next((p for p in protocols
                     if p.__class__.__name__ in ('ipv4', 'ipv6')), None)
        tcudp = next((p for p in protocols
                      if p.__class__.__name__ in ('tcp', 'udp')), None)

        vlanid_src = self.get_vlanid(
            eth.src, ipvx.src if ipvx is not None else None)
        vlanid_dst = self.get_vlanid(
            eth.dst, ipvx.dst if ipvx is not None else None)

        # Devices banned from the network.
        if vlanid_src < 0 or vlanid_dst < 0:
            return (disallow_action,)
        # Hosts in different vLANs may not talk to each other.
        if vlanid_src != vlanid_dst:
            return (disallow_action,)

        # Protocol rules, first match wins.
        present = {p.__class__.__name__.lower() for p in protocols}
        for rule in self.rules.get('special_protocols', list()):
            match_action = allow_action if rule.get('action', True) else disallow_action
            if rule['protocol'].lower() in present:
                return (match_action,)

        # TCP/UDP port-range rules, first match wins; never cached.
        if tcudp is not None:
            for rule in self.rules.get('special_ports', list()):
                match_action = allow_action if rule.get('action', True) else disallow_action
                ports = []
                if rule.get('src_matters'):
                    ports.append(tcudp.src_port)
                if rule.get('dst_matters'):
                    ports.append(tcudp.dst_port)
                start = rule.get('portrange_start', 0)
                stop = rule.get('portrange_stop', 65535)
                if any(start <= port <= stop for port in ports):
                    return (match_action, False)

        return (default_action,)

    def manage_from_gui(self):
        """Open the Tkinter rule editor for this firewall (blocks until Quit)."""
        FirewallGuiManager(self).run()


class FirewallGuiManager(object):
    """Tkinter editor for a Firewall's rules; every change is saved at once."""

    def __init__(self, firewall):
        self.firewall = firewall
        self.root = tkinter.Tk()

    def run(self):
        """Show the main menu and run the Tk main loop until close() is called."""
        self.menu_screen()
        self.root.mainloop()

    def clear_screen(self):
        """Destroy every widget gridded in the root window.

        Destroying (not just un-gridding) keeps each screen change from
        leaking the previous screen's widgets.
        """
        for item in self.root.grid_slaves():
            item.destroy()

    def close(self):
        """Leave the Tk main loop started by run()."""
        self.root.quit()

    def menu_screen(self):
        """Main menu: links to the rule screens and the default-action switch."""
        self.clear_screen()
        tkinter.Label(self.root, text="Firewall").grid(column=0, row=0, columnspan=3)
        entries = (
            ("Blacklist", self.blacklist_screen),
            ("vLANs", self.vlan_screen),
            ("Protocols", self.proto_screen),
            ("Ports", self.port_screen),
        )
        for row, (text, command) in enumerate(entries, start=1):
            tkinter.Button(self.root, text=text, command=command
                           ).grid(column=0, row=row, columnspan=3, sticky="ew")
        tkinter.Label(self.root, text="Default:").grid(column=0, row=5)
        r1 = tkinter.Radiobutton(self.root, text="Allow", value=True,
                                 command=lambda: self._set_default_action(True))
        r2 = tkinter.Radiobutton(self.root, text="Block", value=False,
                                 command=lambda: self._set_default_action(False))
        r1.grid(column=1, row=5)
        r2.grid(column=2, row=5)
        if self.firewall.rules.get('default_action', True):
            r1.select()
            r2.deselect()
        else:
            r1.deselect()
            r2.select()
        tkinter.Button(self.root, text="Quit", command=self.close
                       ).grid(column=0, row=6, columnspan=3, sticky="ew")

    def _set_default_action(self, allow):
        """Store the default action (True = allow) and save the rules."""
        self.firewall.rules['default_action'] = allow
        self.firewall.save_rules()

    @staticmethod
    def _shift(lst, ndxs, step):
        """Move the items at positions *ndxs* by *step* (-1 up, +1 down) in place.

        Selected items keep their relative order and move as a block; an item
        blocked by the list edge (or by a blocked selected neighbour) stays
        put.  Out-of-range positions are ignored.  Returns the sorted new
        positions of the selected items.
        """
        valid = {i for i in ndxs if 0 <= i < len(lst)}
        # Handle the item nearest the destination edge first, so a run of
        # adjacent selected items is not reshuffled.
        settled = set()  # positions now holding a selected item
        for i in sorted(valid, reverse=step > 0):
            j = i + step
            if 0 <= j < len(lst) and j not in settled:
                lst[i], lst[j] = lst[j], lst[i]
                i = j
            settled.add(i)
        return sorted(settled)

    def _moveup(self, lst, ndxs):
        """Move the items at *ndxs* one slot up; return their new positions."""
        return self._shift(lst, ndxs, -1)

    def _movedown(self, lst, ndxs):
        """Move the items at *ndxs* one slot down; return their new positions."""
        return self._shift(lst, ndxs, +1)

    @staticmethod
    def _selected_positions(tree, id2pos):
        """Return the list positions of the rows currently selected in *tree*."""
        return [id2pos[iid] for iid in (tree.selection() or ())]

    @staticmethod
    def _fill_tree(tree, rows, selected=()):
        """Insert *rows* (tuples of column values) into *tree*.

        Rows whose position is in *selected* are selected.  Returns a dict
        mapping each inserted item id to its position in *rows*.
        """
        id2pos = dict()
        for ndx, values in enumerate(rows):
            iid = tree.insert('', 'end', values=values)
            id2pos[iid] = ndx
            if ndx in selected:
                tree.selection_add(iid)
        return id2pos

    def _move_and_redraw(self, rules, tree, id2pos, step, redraw):
        """Move the selected rows of *rules*, save, then redraw keeping the selection."""
        new_positions = self._shift(
            rules, self._selected_positions(tree, id2pos), step)
        # Save before redrawing: a redraw may re-create a (possibly empty)
        # static vLAN list that save_rules would otherwise detach again.
        self.firewall.save_rules()
        redraw(new_positions)

    def _remove_and_redraw(self, rules, tree, id2pos, redraw):
        """Delete the selected rows of *rules*, save, then redraw."""
        # Delete from the highest index down so earlier deletions do not
        # shift the positions still pending.
        for pos in sorted(set(self._selected_positions(tree, id2pos)), reverse=True):
            del rules[pos]
        self.firewall.save_rules()
        redraw()

    def _add_list_buttons(self, rules, tree, id2pos, redraw, add_command):
        """Grid the Move up / Move down / Add / Remove buttons of a list screen.

        *redraw* is called with the positions to keep selected after a move,
        or with no argument after a removal.
        """
        tkinter.Button(self.root, text='Move up',
                       command=lambda: self._move_and_redraw(rules, tree, id2pos, -1, redraw)
                       ).grid(column=0, row=3, sticky="ew")
        tkinter.Button(self.root, text='Move down',
                       command=lambda: self._move_and_redraw(rules, tree, id2pos, +1, redraw)
                       ).grid(column=1, row=3, sticky="ew")
        tkinter.Button(self.root, text='Add', command=add_command
                       ).grid(column=0, row=4, sticky="ew")
        tkinter.Button(self.root, text='Remove',
                       command=lambda: self._remove_and_redraw(rules, tree, id2pos, redraw)
                       ).grid(column=1, row=4, sticky="ew")

    def blacklist_screen(self, group=-1, title="Blacklist", parent=None,
                         selected=()):
        """Show and edit the members (MAC/IPv4/IPv6) of static vLAN *group*.

        Group -1 is the blacklist.  *parent* is the screen the back button
        returns to (the main menu by default); *selected* lists row positions
        to highlight.
        """
        if parent is None:
            parent = self.menu_screen
        self.clear_screen()
        # Create the group on first use; Firewall.save_rules drops empty
        # groups, so it only persists once it has members.
        rules = self.firewall.rules['static_vlans'].setdefault(group, list())
        tkinter.Label(self.root, text=title).grid(column=0, row=0, columnspan=2)
        tree = ttk.Treeview(self.root, show='headings', columns=('ipmac',))
        tree.heading('ipmac', text='MAC/IPv4/IPv6')
        tree.grid(column=0, row=2, columnspan=2, sticky="ew")
        # One-element tuples: a bare string would be split by Tk on blanks.
        id2pos = self._fill_tree(tree, [(item,) for item in rules], selected)

        def redraw(selected=()):
            self.blacklist_screen(group, title, parent, selected)

        self._add_list_buttons(rules, tree, id2pos, redraw,
                               lambda: self._addipmac(rules, redraw))
        tkinter.Button(self.root, text="<", command=parent
                       ).grid(column=0, row=0, sticky="nw")

    def _addipmac(self, lst, callback):
        """Screen asking for one MAC/IPv4/IPv6 address to append to *lst*.

        Save appends the stripped entry (a blank entry is refused), saves the
        rules and calls *callback*; the back button only calls *callback*.
        """
        self.clear_screen()
        tkinter.Label(self.root, text="Add MAC/IPv4/IPv6").grid(column=0, row=0)
        tkinter.Button(self.root, text="<", command=callback
                       ).grid(column=0, row=0, sticky="nw")
        entry = tkinter.Entry(self.root)
        entry.grid(column=0, row=1, sticky="ew")

        def save():
            value = entry.get().strip()
            if not value:
                entry.focus()
                return
            lst.append(value)
            self.firewall.save_rules()
            callback()

        tkinter.Button(self.root, text="Save", command=save
                       ).grid(column=0, row=2, sticky="e")

    def _addvlanid(self):
        """Screen asking for a vLAN id, then opening that vLAN's member list.

        An id that is not a positive integer is replaced in the entry by the
        lowest unused positive id and the user must press Save again.
        """
        def validate_and_proceed(entry, vlanrules, proceed):
            try:
                val = int(entry.get())
            except ValueError:
                val = 0
            if val <= 0:
                # Ids <= 0 are reserved (blacklist / default vLAN).
                val = 1
                while val in vlanrules:
                    val += 1
                entry.delete(0, tkinter.END)
                entry.insert(0, str(val))
                return
            proceed(val, 'vLAN #%d' % val, self.vlan_screen)

        self.clear_screen()
        tkinter.Label(self.root, text="Add vLAN").grid(column=0, row=0)
        tkinter.Button(self.root, text="<", command=self.vlan_screen
                       ).grid(column=0, row=0, sticky="nw")
        entry = tkinter.Entry(self.root)
        entry.grid(column=0, row=1, sticky="ew")
        tkinter.Button(self.root, text="Save",
                       command=lambda: validate_and_proceed(
                           entry, self.firewall.rules['static_vlans'],
                           self.blacklist_screen)
                       ).grid(column=0, row=2, sticky="e")

    def _addprotorule(self):
        """Screen for appending a rule to ``special_protocols``.

        The protocol name is stored stripped and lower-cased.
        """
        def validate_and_proceed(btn, entry, lst, proceed):
            prt = entry.get().strip()
            if not prt:
                entry.focus()
                return
            lst.append({
                'action': bool(btn.get()),
                # Store the stripped name: surrounding blanks would prevent
                # the rule from ever matching.
                'protocol': prt.lower(),
            })
            proceed()

        self.clear_screen()
        tkinter.Label(self.root, text="Add protocol rule").grid(column=0, row=0)
        tkinter.Button(self.root, text="<", command=self.proto_screen
                       ).grid(column=0, row=0, sticky="nw")
        btn_var = tkinter.IntVar()
        btn = tkinter.Checkbutton(self.root, text="Allow this protocol to flow",
                                  variable=btn_var)
        btn.grid(column=0, row=1)
        entry = tkinter.Entry(self.root)
        entry.grid(column=0, row=2, sticky="ew")
        tkinter.Button(self.root, text="Save",
                       command=lambda: validate_and_proceed(
                           btn_var, entry, self.firewall.rules['special_protocols'],
                           lambda: (self.firewall.save_rules(), self.proto_screen()))
                       ).grid(column=0, row=3, sticky="e")

    def _addportrule(self):
        """Screen for appending a TCP/UDP port-range rule to ``special_ports``."""
        def read_port(entry):
            """Return the port in *entry*, or None after clearing and focusing it."""
            try:
                port = int(entry.get())
            except ValueError:
                port = -1
            if 0 <= port <= 65535:
                return port
            entry.delete(0, tkinter.END)
            entry.focus()
            return None

        def validate_and_proceed(btn1, btn2, btn3, entry1, entry2, lst, proceed):
            ps = read_port(entry1)
            if ps is None:
                return
            pe = read_port(entry2)
            if pe is None:
                return
            if ps > pe:
                # Swap the bounds in the form and let the user confirm.
                entry1.delete(0, tkinter.END)
                entry1.insert(0, str(pe))
                entry2.delete(0, tkinter.END)
                entry2.insert(0, str(ps))
                return
            if not (btn2[1].get() or btn3[1].get()):
                # At least one of source/destination must be checked.
                btn2[0].select()
                btn3[0].select()
                return
            lst.append({
                'action': bool(btn1[1].get()),
                'src_matters': bool(btn2[1].get()),
                'dst_matters': bool(btn3[1].get()),
                'portrange_start': ps,
                'portrange_stop': pe,
            })
            proceed()

        self.clear_screen()
        tkinter.Label(self.root, text="Add port range rule"
                      ).grid(column=0, row=0, columnspan=2)
        tkinter.Button(self.root, text="<", command=self.port_screen
                       ).grid(column=0, row=0, columnspan=2, sticky="nw")
        btn_var1 = tkinter.IntVar()
        btn_var2 = tkinter.IntVar()
        btn_var3 = tkinter.IntVar()
        btn1 = tkinter.Checkbutton(self.root, text="Allow this port range to flow",
                                   variable=btn_var1)
        btn2 = tkinter.Checkbutton(self.root, text="Matches at packet's port source field",
                                   variable=btn_var2)
        btn3 = tkinter.Checkbutton(self.root, text="Matches at packet's port destination field",
                                   variable=btn_var3)
        btn1.grid(column=0, row=1, columnspan=2)
        btn2.grid(column=0, row=4, columnspan=2)
        btn3.grid(column=0, row=5, columnspan=2)
        tkinter.Label(self.root, text="Port range start").grid(column=0, row=2, sticky="e")
        tkinter.Label(self.root, text="Port range end").grid(column=0, row=3, sticky="e")
        entry1 = tkinter.Entry(self.root)
        entry1.grid(column=1, row=2, sticky="ew")
        entry2 = tkinter.Entry(self.root)
        entry2.grid(column=1, row=3, sticky="ew")
        tkinter.Button(self.root, text="Save",
                       command=lambda: validate_and_proceed(
                           (btn1, btn_var1), (btn2, btn_var2), (btn3, btn_var3),
                           entry1, entry2, self.firewall.rules['special_ports'],
                           lambda: (self.firewall.save_rules(), self.port_screen()))
                       ).grid(column=0, row=6, columnspan=2, sticky="e")
        btn2.select()
        btn3.select()

    def vlan_screen(self):
        """List the user-defined vLANs (ids > 0) with their member counts."""
        self.clear_screen()
        vlans = self.firewall.rules['static_vlans']
        tkinter.Label(self.root, text="vLAN list"
                      ).grid(column=0, row=0, columnspan=2)
        tkinter.Button(self.root, text="<", command=self.menu_screen
                       ).grid(column=0, row=0, columnspan=2, sticky="nw")
        tree = ttk.Treeview(self.root, show='headings', columns=('id', 'nor'))
        tree.heading('id', text='ID')
        tree.heading('nor', text='Number of Rules')
        tree.grid(column=0, row=2, columnspan=2, sticky="ew")
        iid2vlan = dict()
        for idno, members in vlans.items():
            if idno <= 0:  # blacklist and default vLAN are not listed here
                continue
            iid2vlan[tree.insert('', 'end', values=(idno, len(members)))] = idno

        def open4edit():
            ids = [iid2vlan[iid] for iid in (tree.selection() or ())]
            if ids:
                self.blacklist_screen(ids[0], 'vLAN #%d' % ids[0], self.vlan_screen)

        tkinter.Button(self.root, text='Edit', command=open4edit
                       ).grid(column=0, row=3, sticky="ew")
        tkinter.Button(self.root, text='Create', command=self._addvlanid
                       ).grid(column=1, row=3, sticky="ew")

    def proto_screen(self, selected=()):
        """Show and edit the ordered protocol rules; *selected* rows are highlighted."""
        self.clear_screen()
        tkinter.Label(self.root, text="Protocols"
                      ).grid(column=0, row=0, columnspan=2)
        tkinter.Button(self.root, text="<", command=self.menu_screen
                       ).grid(column=0, row=0, columnspan=2, sticky="nw")
        tree = ttk.Treeview(self.root, show='headings',
                            columns=('seq', 'protocol', 'action'))
        tree.heading('seq', text='#')
        tree.heading('action', text='Action')
        tree.heading('protocol', text='Protocol')
        tree.grid(column=0, row=2, columnspan=2, sticky="ew")
        rules = self.firewall.rules.setdefault('special_protocols', list())
        rows = [(ndx + 1, item['protocol'].upper(),
                 'Allow' if item['action'] else 'Drop')
                for ndx, item in enumerate(rules)]
        id2pos = self._fill_tree(tree, rows, selected)
        self._add_list_buttons(rules, tree, id2pos, self.proto_screen,
                               self._addprotorule)

    @staticmethod
    def _describe_port_rule(ndx, item):
        """Return the Treeview row (seq, range, direction, action) for a port rule."""
        start, stop = item['portrange_start'], item['portrange_stop']
        portrange = str(start) if start == stop else "%d .. %d" % (start, stop)
        if item['src_matters'] and item['dst_matters']:
            srcdst = 'Both source and destination'
        elif item['src_matters']:
            srcdst = 'Source'
        else:
            srcdst = 'Destination'
        return (ndx + 1, portrange, srcdst, 'Allow' if item['action'] else 'Drop')

    def port_screen(self, selected=()):
        """Show and edit the ordered port-range rules; *selected* rows are highlighted."""
        self.clear_screen()
        tkinter.Label(self.root, text="Ports"
                      ).grid(column=0, row=0, columnspan=2)
        tkinter.Button(self.root, text="<", command=self.menu_screen
                       ).grid(column=0, row=0, columnspan=2, sticky="nw")
        tree = ttk.Treeview(self.root, show='headings',
                            columns=('seq', 'portrange', 'srcdst', 'action'))
        tree.heading('seq', text='#')
        tree.heading('action', text='Action')
        tree.heading('srcdst', text='Source/Destination')
        tree.heading('portrange', text='Port range')
        tree.grid(column=0, row=2, columnspan=2, sticky="ew")
        rules = self.firewall.rules.setdefault('special_ports', list())
        rows = [self._describe_port_rule(ndx, item) for ndx, item in enumerate(rules)]
        id2pos = self._fill_tree(tree, rows, selected)
        self._add_list_buttons(rules, tree, id2pos, self.port_screen,
                               self._addportrule)


# Copyright (C) 2011 Nippon Telegraph and Telephone Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class SimpleSwitch13(app_manager.RyuApp):
    """OpenFlow 1.3 MAC-learning switch whose forwarding passes a Firewall.

    Each switch (datapath) gets its own Firewall instance, which may replace
    the output actions of a new flow (empty actions drop it) and decide
    whether the verdict may be installed as a flow entry.
    """
    OFP_VERSIONS = [ofproto_v1_3.OFP_VERSION]

    def __init__(self, *args, **kwargs):
        super(SimpleSwitch13, self).__init__(*args, **kwargs)
        self.mac_to_port = {}  # datapath id -> {MAC: port}
        self.firewalls = dict()  # datapath id -> Firewall, created lazily

    @set_ev_cls(ofp_event.EventOFPSwitchFeatures, CONFIG_DISPATCHER)
    def switch_features_handler(self, ev):
        """Install the table-miss entry sending unmatched packets to the controller."""
        datapath = ev.msg.datapath
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser

        # install table-miss flow entry
        #
        # We specify NO BUFFER to max_len of the output action due to
        # OVS bug. At this moment, if we specify a lesser number, e.g.,
        # 128, OVS will send Packet-In with invalid buffer_id and
        # truncated packet data. In that case, we cannot output packets
        # correctly.  The bug has been fixed in OVS v2.1.0.
        match = parser.OFPMatch()
        actions = [parser.OFPActionOutput(ofproto.OFPP_CONTROLLER,
                                          ofproto.OFPCML_NO_BUFFER)]
        self.add_flow(datapath, 0, match, actions)

    def add_flow(self, datapath, priority, match, actions, buffer_id=None,
                 idle_timeout=0, hard_timeout=0):
        """Send a FlowMod applying *actions* to packets matching *match*.

        When *buffer_id* is given, the switch also applies the new entry to
        the packet it has buffered under that id.
        """
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser

        inst = [parser.OFPInstructionActions(ofproto.OFPIT_APPLY_ACTIONS,
                                             actions)]
        if buffer_id:
            mod = parser.OFPFlowMod(datapath=datapath, buffer_id=buffer_id,
                                    priority=priority, match=match,
                                    instructions=inst,
                                    idle_timeout=idle_timeout,
                                    hard_timeout=hard_timeout)
        else:
            mod = parser.OFPFlowMod(datapath=datapath, priority=priority,
                                    match=match, instructions=inst,
                                    idle_timeout=idle_timeout,
                                    hard_timeout=hard_timeout)
        datapath.send_msg(mod)

    @set_ev_cls(ofp_event.EventOFPPacketIn, MAIN_DISPATCHER)
    def _packet_in_handler(self, ev):
        """Learn the source port, ask the firewall, then forward or drop."""
        # If you hit this you might want to increase
        # the "miss_send_length" of your switch
        if ev.msg.msg_len < ev.msg.total_len:
            self.logger.debug("packet truncated: only %s of %s bytes",
                              ev.msg.msg_len, ev.msg.total_len)
        msg = ev.msg
        datapath = msg.datapath
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser
        in_port = msg.match['in_port']

        pkt = packet.Packet(msg.data)
        eth = pkt.get_protocols(ethernet.ethernet)[0]

        if eth.ethertype == ether_types.ETH_TYPE_LLDP:
            # ignore lldp packet
            return
        dst = eth.dst
        src = eth.src

        dpid = datapath.id
        self.mac_to_port.setdefault(dpid, {})

        self.logger.info("packet in %s %s %s %s", dpid, src, dst, in_port)

        # learn a mac address to avoid FLOOD next time.
        self.mac_to_port[dpid][src] = in_port

        if dst in self.mac_to_port[dpid]:
            out_port = self.mac_to_port[dpid][dst]
        else:
            out_port = ofproto.OFPP_FLOOD

        actions = [parser.OFPActionOutput(out_port)]

        # Run the firewall on the packet; it may replace the actions.
        firewall = self.firewalls.get(dpid)
        if firewall is None:
            firewall = self.firewalls[dpid] = Firewall(datapath)
        verdict = firewall.actions_for_new_flow(
            event=ev, message=msg, datapath=datapath, ofproto=ofproto,
            eth=eth, parser=parser, allow_action=actions)
        actions = verdict[0]
        # A verdict without a second element may be cached in the flow table.
        save_in_flow_table = verdict[1] if len(verdict) > 1 else True

        # install a flow to avoid packet_in next time
        if save_in_flow_table and out_port != ofproto.OFPP_FLOOD:
            # Match the source MAC too: the firewall verdict depends on the
            # source host, so a flow keyed only on (in_port, eth_dst) would
            # apply one host's allow/drop verdict to every host on that port.
            # NOTE(review): the entry still covers later packets of any
            # protocol/port between these two hosts, so protocol and port
            # rules only ever see the first packet of a cached pair.
            match = parser.OFPMatch(in_port=in_port, eth_dst=dst, eth_src=src)
            # verify if we have a valid buffer_id, if yes avoid to send both
            # flow_mod & packet_out
            if msg.buffer_id != ofproto.OFP_NO_BUFFER:
                self.add_flow(datapath, 1, match, actions, msg.buffer_id,
                              idle_timeout=30, hard_timeout=90)
                return
            self.add_flow(datapath, 1, match, actions,
                          idle_timeout=30, hard_timeout=90)

        data = None
        if msg.buffer_id == ofproto.OFP_NO_BUFFER:
            data = msg.data

        out = parser.OFPPacketOut(datapath=datapath, buffer_id=msg.buffer_id,
                                  in_port=in_port, actions=actions, data=data)
        datapath.send_msg(out)


if __name__ == '__main__':
    # Stand-alone mode: edit the rules file through the GUI only.
    Firewall(None).manage_from_gui()