618 lines
27 KiB
Python
Executable File
618 lines
27 KiB
Python
Executable File
from ryu.base import app_manager
|
|
from ryu.lib import ofctl_v1_3
|
|
from ryu.lib.packet import ether_types
|
|
from ryu.lib.packet import ethernet
|
|
from ryu.ofproto import ofproto_v1_3
|
|
from ryu.controller.handler import set_ev_cls
|
|
from ryu.controller.handler import CONFIG_DISPATCHER, MAIN_DISPATCHER
|
|
from ryu.controller import ofp_event
|
|
from ryu.lib.packet import packet
|
|
from tkinter import ttk
|
|
import tkinter
|
|
import json
|
|
import os
|
|
|
|
|
|
class Firewall:
    """Packet filter driven by a JSON rules file.

    The rules (``self.rules``, normalized by ``_fix_rules``) hold:

    * ``default_action`` -- bool, verdict when no other rule decides.
    * ``static_vlans`` -- dict mapping an int vLAN id to a list of member
      MAC/IP strings. Negative ids mark banned devices; a device listed
      nowhere is in vLAN 0. Traffic between different vLANs is dropped.
    * ``special_protocols`` -- list of ``{'protocol', 'action'}`` dicts,
      checked in order; the first match decides.
    * ``special_ports`` -- list of ``{'portrange_start', 'portrange_stop',
      'src_matters', 'dst_matters', 'action'}`` dicts, checked in order;
      the first match decides.
    """

    def __init__(self, datapath, rule_path="wallrules.json", *args, **kwargs):
        super(Firewall, self).__init__(*args, **kwargs)
        self.datapath = datapath
        self._rule_path = rule_path
        self.rules = dict()
        # mtime of the rules file when it was last read or written; 0 forces
        # the first load_rules() to read the file.
        self._rules_age = 0
        # Set to True whenever the in-memory rules are (re)loaded or saved.
        self.are_rules_changed = False
        try:
            self.load_rules()
        except (OSError, ValueError):
            # Missing, unreadable or malformed rules file: start over from the
            # defaults and (re)create the file.
            self.rules = dict()
            self.save_rules()

    def save_rules(self):
        """Normalize the rules and write them to the rules file.

        The JSON goes to a temporary sibling file that then replaces the rules
        file, so a concurrent reader (e.g. a controller polling the file while
        the GUI saves) never sees a truncated or half-written file.
        """
        self._fix_rules()
        tmp_path = self._rule_path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.rules, indent=4, sort_keys=True))
        os.replace(tmp_path, self._rule_path)
        self.are_rules_changed = True
        self._update_rules_age()

    def _update_rules_age(self):
        """Remember the rules file's current mtime as the last-synced time."""
        self._rules_age = os.path.getmtime(self._rule_path)

    def load_rules(self):
        """Re-read the rules file if it changed since it was last read/written.

        On failure the previously loaded rules are kept.

        Raises:
            OSError: the rules file is missing or unreadable.
            ValueError: the file is not valid JSON or does not have the
                expected structure.
        """
        if not self._check_for_disk_updates():
            return
        with open(self._rule_path, encoding='utf-8') as f:
            loaded = json.load(f)
        previous = self.rules
        self.rules = loaded
        try:
            if not isinstance(loaded, dict):
                raise TypeError('top-level JSON value is not an object')
            self._fix_rules()
        except (TypeError, ValueError, AttributeError) as err:
            self.rules = previous
            raise ValueError('malformed rules file %r: %s'
                             % (self._rule_path, err)) from err
        self.are_rules_changed = True
        self._update_rules_age()

    def _fix_rules(self):
        """Fill in missing rule sections and normalize ``static_vlans`` in place."""
        self.rules['default_action'] = self.rules.get('default_action', True)
        # JSON object keys are always strings, but vLAN ids are used as ints.
        # vLANs without members are dropped.
        self.rules['static_vlans'] = {
            int(vlan_id): members
            for vlan_id, members in self.rules.get('static_vlans', dict()).items()
            if len(members) > 0
        }
        self.rules['special_protocols'] = self.rules.get(
            'special_protocols', list())
        self.rules['special_ports'] = self.rules.get('special_ports', list())

    def _check_for_disk_updates(self):
        """Return True if the rules file is newer than the last read/write."""
        return self._rules_age < os.path.getmtime(self._rule_path)

    def get_vlanid(self, mac, ip):
        """Return the id of the first vLAN listing ``mac`` or ``ip``, else 0."""
        for vlan_id, members in self.rules.get('static_vlans', dict()).items():
            if mac in members or (ip is not None and ip in members):
                return int(vlan_id)
        return 0

    def actions_for_new_flow(self, event, message, datapath, ofproto, eth,
                             parser, allow_action, disallow_action=None):
        """Decide what to do with a packet-in.

        The rules file is re-read first if it changed on disk; if that read
        fails, the last successfully loaded rules stay in effect.

        Args:
            event, datapath, ofproto, parser: accepted for the caller's
                convenience; not used.
            message: the packet-in message; its ``data`` is parsed.
            eth: the packet's parsed Ethernet header (``src``/``dst``).
            allow_action: action list returned when the packet may pass.
            disallow_action: action list returned when it is blocked
                (default: a new empty list).

        Returns:
            tuple: ``(actions,)`` or ``(actions, save_in_flow_table)``; a
            missing flag means the verdict may be cached as a flow entry.
        """
        if disallow_action is None:
            disallow_action = []
        try:
            self.load_rules()
        except (OSError, ValueError):
            # Deliberate best effort: a transiently missing or corrupt file
            # must not abort the packet-in handler. It is retried next packet.
            pass
        protocols = list(packet.Packet(message.data).protocols)
        return self._verdict_for(protocols, eth, allow_action, disallow_action)

    def _verdict_for(self, protocols, eth, allow_action, disallow_action):
        """Decide on a packet given its parsed headers (see actions_for_new_flow).

        Headers are recognized by class name ('ipv4'/'ipv6', 'tcp'/'udp', ...);
        only the first IP and the first TCP/UDP header are considered.
        """
        default_action = allow_action if self.rules.get(
            'default_action', True) else disallow_action

        ipvx = next((p for p in protocols
                     if p.__class__.__name__ in ('ipv4', 'ipv6')), None)
        tcudp = next((p for p in protocols
                      if p.__class__.__name__ in ('tcp', 'udp')), None)

        vlanid_src = self.get_vlanid(
            eth.src, ipvx.src if ipvx is not None else None)
        vlanid_dst = self.get_vlanid(
            eth.dst, ipvx.dst if ipvx is not None else None)

        # devices banned from the network
        if vlanid_src < 0 or vlanid_dst < 0:
            return (disallow_action,)
        # different vlans
        if vlanid_src != vlanid_dst:
            return (disallow_action,)

        # Protocol and port verdicts depend on packet contents, while the
        # caller's flow entries match only on switch port / MAC addresses.
        # Caching them (save_in_flow_table=True) would apply one packet's
        # verdict to every later packet hitting that flow entry.
        present = {p.__class__.__name__.lower() for p in protocols}
        for rule in self.rules.get('special_protocols', list()):
            if rule['protocol'].lower() in present:
                match_action = allow_action if rule.get(
                    'action', True) else disallow_action
                return (match_action, False)

        if tcudp is not None:
            for rule in self.rules.get('special_ports', list()):
                start = rule.get('portrange_start', 0)
                stop = rule.get('portrange_stop', 65535)
                ports = []
                if rule.get('src_matters'):
                    ports.append(tcudp.src_port)
                if rule.get('dst_matters'):
                    ports.append(tcudp.dst_port)
                if any(start <= port <= stop for port in ports):
                    match_action = allow_action if rule.get(
                        'action', True) else disallow_action
                    return (match_action, False)

        # For the same reason, the default verdict is only cacheable when no
        # protocol or port rule could decide differently for later packets.
        cacheable = not (self.rules.get('special_protocols')
                         or self.rules.get('special_ports'))
        return (default_action, cacheable)

    def manage_from_gui(self):
        """Open the Tk rules editor for this firewall (blocks until closed)."""
        FirewallGuiManager(self).run()
|
|
|
|
|
|
class FirewallGuiManager(object):
    """Tkinter front-end for editing a Firewall's rules.

    Each screen clears the root window and redraws itself; every edit is
    written to disk immediately through ``Firewall.save_rules``.
    """

    def __init__(self, firewall):
        self.firewall = firewall
        self.root = tkinter.Tk()
        # Backing variable of the default-action radio buttons, held here so
        # it stays alive as long as the buttons that use it.
        self._default_action_var = None

    def run(self):
        """Show the main menu and block in the Tk main loop until closed."""
        self.menu_screen()
        self.root.mainloop()

    def clear_screen(self):
        """Destroy every widget gridded into the root window."""
        # destroy() rather than grid_forget(): forgotten widgets stay alive
        # and would accumulate every time a screen is redrawn.
        for item in self.root.grid_slaves():
            item.destroy()

    def close(self):
        """Leave the Tk main loop, which makes ``run`` return."""
        self.root.quit()

    @staticmethod
    def _selected_positions(tree, id2pos):
        """Map the tree's selected item ids (in selection order) through ``id2pos``."""
        return [id2pos[sel] for sel in tree.selection()]

    @staticmethod
    def _fill_tree(tree, rows, selected=()):
        """Append ``rows`` to ``tree``, select the positions in ``selected``.

        Returns:
            dict: tree item id -> row position.
        """
        id2pos = dict()
        for pos, values in enumerate(rows):
            iid = tree.insert('', 'end', values=values)
            id2pos[iid] = pos
            if pos in selected:
                tree.selection_add(iid)
        return id2pos

    def _add_list_buttons(self, tree, id2pos, rules, redraw, add_command):
        """Grid the Move up / Move down / Add / Remove buttons of a list screen.

        Args:
            tree, id2pos: the screen's Treeview and its item id -> position map.
            rules: the list being edited in place.
            redraw: ``redraw(selected)`` redraws the screen with the given
                positions selected.
            add_command: callable for the Add button.
        """
        def move(mover):
            new_positions = mover(rules, self._selected_positions(tree, id2pos))
            # Save before redrawing: save_rules() rebuilds static_vlans, and
            # the redrawn screen must capture the list that is actually stored.
            self.firewall.save_rules()
            redraw(new_positions)

        def remove():
            # Pop from the highest position down so that earlier pops do not
            # shift the positions still waiting to be removed.
            for pos in sorted(self._selected_positions(tree, id2pos), reverse=True):
                rules.pop(pos)
            self.firewall.save_rules()
            redraw(())

        tkinter.Button(self.root, text='Move up', command=lambda: move(self._moveup)
                       ).grid(column=0, row=3, sticky="ew")
        tkinter.Button(self.root, text='Move down', command=lambda: move(self._movedown)
                       ).grid(column=1, row=3, sticky="ew")
        tkinter.Button(self.root, text='Add', command=add_command
                       ).grid(column=0, row=4, sticky="ew")
        tkinter.Button(self.root, text='Remove', command=remove
                       ).grid(column=1, row=4, sticky="ew")

    def menu_screen(self):
        """Main menu: links to the rule screens and the default-action switch."""
        self.clear_screen()
        tkinter.Label(self.root, text="Firewall"
                      ).grid(column=0, row=0, columnspan=3)
        entries = (
            ("Blacklist", self.blacklist_screen),
            ("vLANs", self.vlan_screen),
            ("Protocols", self.proto_screen),
            ("Ports", self.port_screen),
        )
        for row, (text, command) in enumerate(entries, start=1):
            tkinter.Button(self.root, text=text, command=command
                           ).grid(column=0, row=row, columnspan=3, sticky="ew")

        tkinter.Label(self.root, text="Default:").grid(column=0, row=5)
        var = tkinter.BooleanVar(
            master=self.root,
            value=bool(self.firewall.rules.get('default_action', True)))
        self._default_action_var = var

        def set_default_action():
            self.firewall.rules['default_action'] = var.get()
            self.firewall.save_rules()

        tkinter.Radiobutton(self.root, text="Allow", value=True, variable=var,
                            command=set_default_action).grid(column=1, row=5)
        tkinter.Radiobutton(self.root, text="Block", value=False, variable=var,
                            command=set_default_action).grid(column=2, row=5)
        tkinter.Button(self.root, text="Quit", command=self.close
                       ).grid(column=0, row=6, columnspan=3, sticky="ew")

    def _moveup(self, lst, ndxs):
        """Move the items at positions ``ndxs`` of ``lst`` one step up, in place.

        Selected items keep their relative order. An item already at the top,
        or directly below a selected item that cannot move, stays put.
        Out-of-range positions are ignored.

        Returns:
            list[int]: the new positions of the selected items, ascending.
        """
        new_positions = []
        floor = 0  # lowest position an item may still move into
        # Top-most first, so a block of adjacent items moves as a unit.
        for ndx in sorted(set(ndxs)):
            if ndx < 0 or ndx >= len(lst):
                continue
            if ndx <= floor:
                floor = ndx + 1
                new_positions.append(ndx)
                continue
            lst[ndx - 1], lst[ndx] = lst[ndx], lst[ndx - 1]
            new_positions.append(ndx - 1)
        return sorted(new_positions)

    def _movedown(self, lst, ndxs):
        """Move the items at positions ``ndxs`` of ``lst`` one step down, in place.

        Mirror image of ``_moveup``.

        Returns:
            list[int]: the new positions of the selected items, ascending.
        """
        new_positions = []
        ceiling = len(lst) - 1  # highest position an item may still move into
        # Bottom-most first, so a block of adjacent items moves as a unit.
        for ndx in sorted(set(ndxs), reverse=True):
            if ndx < 0 or ndx >= len(lst):
                continue
            if ndx >= ceiling:
                ceiling = ndx - 1
                new_positions.append(ndx)
                continue
            lst[ndx + 1], lst[ndx] = lst[ndx], lst[ndx + 1]
            new_positions.append(ndx + 1)
        return sorted(new_positions)

    def blacklist_screen(self, group=-1, title="Blacklist", parent=None, selected=()):
        """Edit the member list (MAC/IPv4/IPv6 strings) of one static_vlans entry.

        Args:
            group: static_vlans key to edit. The default -1 is the blacklist
                (the firewall drops traffic of devices in negative vLANs).
            title: heading shown on the screen.
            parent: screen callable for the "<" button; defaults to the menu.
            selected: positions to pre-select in the list.
        """
        if parent is None:
            parent = self.menu_screen
        self.clear_screen()
        rules = self.firewall.rules['static_vlans'].setdefault(group, list())
        tkinter.Label(self.root, text=title
                      ).grid(column=0, row=0, columnspan=2)
        tree = ttk.Treeview(self.root, show='headings', columns=('ipmac',))
        tree.heading('ipmac', text='MAC/IPv4/IPv6')
        tree.grid(column=0, row=2, columnspan=2, sticky="ew")
        # One-element tuples: Tk parses ``values`` as a Tcl list, so a bare
        # string containing spaces or braces would not stay in one column.
        id2pos = self._fill_tree(tree, [(item,) for item in rules], selected)
        self._add_list_buttons(
            tree, id2pos, rules,
            redraw=lambda sel: self.blacklist_screen(group, title, parent, sel),
            add_command=lambda: self._addipmac(
                rules, lambda: self.blacklist_screen(group, title, parent)))
        tkinter.Button(self.root, text="<", command=parent
                       ).grid(column=0, row=0, sticky="nw")

    def _addipmac(self, lst, callback):
        """Screen asking for one MAC/IP string to append to ``lst``.

        ``callback`` is invoked to return to the previous screen.
        """
        self.clear_screen()
        tkinter.Label(self.root, text="Add MAC/IPv4/IPv6"
                      ).grid(column=0, row=0)
        tkinter.Button(self.root, text="<", command=callback
                       ).grid(column=0, row=0, sticky="nw")
        entry = tkinter.Entry(self.root)
        entry.grid(column=0, row=1, sticky="ew")

        def save():
            # Members are compared verbatim with packet addresses, so stray
            # whitespace would prevent any match; empty input is refused.
            address = entry.get().strip()
            if not address:
                entry.focus()
                return
            lst.append(address)
            self.firewall.save_rules()
            callback()

        tkinter.Button(self.root, text="Save", command=save
                       ).grid(column=0, row=2, sticky="e")

    def _addvlanid(self):
        """Screen asking for a new vLAN id, then opening its member list."""
        def validate_and_proceed(entry, vlanrules, proceed):
            try:
                val = int(entry.get())
            except ValueError:
                val = 0
            if val <= 0:
                # Ids <= 0 are reserved (0: no vLAN, negative: blacklist), so
                # suggest the first unused positive id instead.
                val = 1
                while val in vlanrules:
                    val += 1
                entry.delete(0, tkinter.END)
                entry.insert(0, str(val))
                return
            proceed(val, 'vLAN #%d' % val, self.vlan_screen)

        self.clear_screen()
        tkinter.Label(self.root, text="Add vLAN"
                      ).grid(column=0, row=0)
        tkinter.Button(self.root, text="<", command=self.vlan_screen
                       ).grid(column=0, row=0, sticky="nw")
        entry = tkinter.Entry(self.root)
        entry.grid(column=0, row=1, sticky="ew")
        tkinter.Button(self.root, text="Save", command=lambda: validate_and_proceed(
            entry, self.firewall.rules['static_vlans'], self.blacklist_screen)
        ).grid(column=0, row=2, sticky="e")

    def _addprotorule(self):
        """Screen for appending a rule to ``special_protocols``."""
        def validate_and_proceed(btn, entry, lst, proceed):
            prt = entry.get().strip()
            if len(prt) <= 0:
                entry.focus()
                return
            lst.append({
                'action': bool(btn.get()),
                # Stored stripped: matching lower-cases it and compares it
                # with header class names.
                'protocol': prt.lower(),
            })
            proceed()

        self.clear_screen()
        tkinter.Label(self.root, text="Add protocol rule"
                      ).grid(column=0, row=0)
        tkinter.Button(self.root, text="<", command=self.proto_screen
                       ).grid(column=0, row=0, sticky="nw")
        btnVar = tkinter.IntVar()
        btn = tkinter.Checkbutton(self.root, text="Allow this protocol to flow", variable=btnVar)
        btn.grid(column=0, row=1)
        entry = tkinter.Entry(self.root)
        entry.grid(column=0, row=2, sticky="ew")
        tkinter.Button(self.root, text="Save", command=lambda: validate_and_proceed(
            btnVar, entry, self.firewall.rules['special_protocols'],
            lambda: (self.firewall.save_rules(), self.proto_screen()))
        ).grid(column=0, row=3, sticky="e")

    def _addportrule(self):
        """Screen for appending a rule to ``special_ports``."""
        def read_port(entry):
            # Return the entry's port number, or None after clearing and
            # focusing the entry when it is not an integer in 0..65535.
            try:
                port = int(entry.get())
            except ValueError:
                port = -1
            if not 0 <= port <= 65535:
                entry.delete(0, tkinter.END)
                entry.focus()
                return None
            return port

        def validate_and_proceed(btn1, btn2, btn3, entry1, entry2, lst, proceed):
            ps = read_port(entry1)
            if ps is None:
                return
            pe = read_port(entry2)
            if pe is None:
                return
            if ps > pe:
                # Swap the fields; the user confirms by saving again.
                entry1.delete(0, tkinter.END)
                entry1.insert(0, str(pe))
                entry2.delete(0, tkinter.END)
                entry2.insert(0, str(ps))
                return
            if not (btn2[1].get() or btn3[1].get()):
                # A rule must look at the source port, destination port or both.
                btn2[0].select()
                btn3[0].select()
                return
            lst.append({
                'action': bool(btn1[1].get()),
                'src_matters': bool(btn2[1].get()),
                'dst_matters': bool(btn3[1].get()),
                'portrange_start': ps,
                'portrange_stop': pe,
            })
            proceed()

        self.clear_screen()
        tkinter.Label(self.root, text="Add port range rule"
                      ).grid(column=0, row=0, columnspan=2)
        tkinter.Button(self.root, text="<", command=self.port_screen
                       ).grid(column=0, row=0, columnspan=2, sticky="nw")
        btnVar1 = tkinter.IntVar()
        btnVar2 = tkinter.IntVar()
        btnVar3 = tkinter.IntVar()
        btn1 = tkinter.Checkbutton(self.root, text="Allow this port range to flow", variable=btnVar1)
        btn2 = tkinter.Checkbutton(self.root, text="Matches at packet's port source field", variable=btnVar2)
        btn3 = tkinter.Checkbutton(self.root, text="Matches at packet's port destination field", variable=btnVar3)
        btn1.grid(column=0, row=1, columnspan=2)
        btn2.grid(column=0, row=4, columnspan=2)
        btn3.grid(column=0, row=5, columnspan=2)
        tkinter.Label(self.root, text="Port range start").grid(column=0, row=2, sticky="e")
        tkinter.Label(self.root, text="Port range end").grid(column=0, row=3, sticky="e")
        entry1 = tkinter.Entry(self.root)
        entry1.grid(column=1, row=2, sticky="ew")
        entry2 = tkinter.Entry(self.root)
        entry2.grid(column=1, row=3, sticky="ew")
        tkinter.Button(self.root, text="Save", command=lambda: validate_and_proceed(
            (btn1, btnVar1), (btn2, btnVar2), (btn3, btnVar3), entry1, entry2,
            self.firewall.rules['special_ports'],
            lambda: (self.firewall.save_rules(), self.port_screen()))
        ).grid(column=0, row=6, columnspan=2, sticky="e")
        btn2.select()
        btn3.select()

    def vlan_screen(self):
        """List the positive vLAN ids with their member counts; edit or create one."""
        self.clear_screen()
        vlans = self.firewall.rules['static_vlans']
        tkinter.Label(self.root, text="vLAN list"
                      ).grid(column=0, row=0, columnspan=2)
        tkinter.Button(self.root, text="<", command=self.menu_screen
                       ).grid(column=0, row=0, columnspan=2, sticky="nw")
        tree = ttk.Treeview(self.root, show='headings', columns=('id', 'nor'))
        tree.heading('id', text='ID')
        tree.heading('nor', text='Number of Rules')
        tree.grid(column=0, row=2, columnspan=2, sticky="ew")
        # 0 means "no vLAN" and negative ids are the blacklist; neither is listed.
        vlan_ids = [idno for idno in vlans if idno > 0]
        id2pos = self._fill_tree(tree, [(idno, len(vlans[idno])) for idno in vlan_ids])

        def open4edit():
            positions = self._selected_positions(tree, id2pos)
            if positions:
                idno = vlan_ids[positions[0]]
                self.blacklist_screen(idno, 'vLAN #%d' % idno, self.vlan_screen)

        tkinter.Button(self.root, text='Edit', command=open4edit
                       ).grid(column=0, row=3, sticky="ew")
        tkinter.Button(self.root, text='Create', command=self._addvlanid
                       ).grid(column=1, row=3, sticky="ew")

    def proto_screen(self, selected=()):
        """List the ``special_protocols`` rules in evaluation order."""
        self.clear_screen()
        tkinter.Label(self.root, text="Protocols"
                      ).grid(column=0, row=0, columnspan=2)
        tkinter.Button(self.root, text="<", command=self.menu_screen
                       ).grid(column=0, row=0, columnspan=2, sticky="nw")
        tree = ttk.Treeview(
            self.root, show='headings', columns=('seq', 'protocol', 'action'))
        tree.heading('seq', text='#')
        tree.heading('action', text='Action')
        tree.heading('protocol', text='Protocol')
        tree.grid(column=0, row=2, columnspan=2, sticky="ew")
        rules = self.firewall.rules.setdefault('special_protocols', list())
        # action defaults to allow, as in Firewall's matching code.
        rows = [(seq, item['protocol'].upper(),
                 'Allow' if item.get('action', True) else 'Drop')
                for seq, item in enumerate(rules, start=1)]
        id2pos = self._fill_tree(tree, rows, selected)
        self._add_list_buttons(tree, id2pos, rules, self.proto_screen, self._addprotorule)

    @staticmethod
    def _port_rule_row(seq, rule):
        """Return the port_screen row for ``rule``, using the matcher's defaults."""
        start = rule.get('portrange_start', 0)
        stop = rule.get('portrange_stop', 65535)
        src = rule.get('src_matters')
        dst = rule.get('dst_matters')
        if start == stop:
            portrange = str(start)
        else:
            portrange = "%d .. %d" % (start, stop)
        if src and dst:
            where = 'Both source and destination'
        elif src:
            where = 'Source'
        elif dst:
            where = 'Destination'
        else:
            where = 'Neither'
        return (seq, portrange, where, 'Allow' if rule.get('action', True) else 'Drop')

    def port_screen(self, selected=()):
        """List the ``special_ports`` rules in evaluation order."""
        self.clear_screen()
        tkinter.Label(self.root, text="Ports"
                      ).grid(column=0, row=0, columnspan=2)
        tkinter.Button(self.root, text="<", command=self.menu_screen
                       ).grid(column=0, row=0, columnspan=2, sticky="nw")
        tree = ttk.Treeview(
            self.root, show='headings', columns=('seq', 'portrange', 'srcdst', 'action'))
        tree.heading('seq', text='#')
        tree.heading('action', text='Action')
        tree.heading('srcdst', text='Source/Destination')
        tree.heading('portrange', text='Port range')
        tree.grid(column=0, row=2, columnspan=2, sticky="ew")
        rules = self.firewall.rules.setdefault('special_ports', list())
        rows = [self._port_rule_row(seq, item) for seq, item in enumerate(rules, start=1)]
        id2pos = self._fill_tree(tree, rows, selected)
        self._add_list_buttons(tree, id2pos, rules, self.port_screen, self._addportrule)
|
|
|
|
# Copyright (C) 2011 Nippon Telegraph and Telephone Corporation.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
# implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
class SimpleSwitch13(app_manager.RyuApp):
    """OpenFlow 1.3 MAC-learning switch that filters traffic through a Firewall.

    Every packet-in is checked by the switch's Firewall. The verdict
    (forward or drop) is applied to the packet and, when the firewall allows
    it, cached as a flow entry.
    """

    OFP_VERSIONS = [ofproto_v1_3.OFP_VERSION]

    def __init__(self, *args, **kwargs):
        super(SimpleSwitch13, self).__init__(*args, **kwargs)
        # dpid -> {MAC address: port it was last seen on}
        self.mac_to_port = {}
        # dpid -> Firewall of that switch, created on its first packet-in
        self.firewalls = dict()

    @set_ev_cls(ofp_event.EventOFPSwitchFeatures, CONFIG_DISPATCHER)
    def switch_features_handler(self, ev):
        """Install the table-miss entry that sends unmatched packets to the controller."""
        datapath = ev.msg.datapath
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser

        # install table-miss flow entry
        #
        # We specify NO BUFFER to max_len of the output action due to
        # OVS bug. At this moment, if we specify a lesser number, e.g.,
        # 128, OVS will send Packet-In with invalid buffer_id and
        # truncated packet data. In that case, we cannot output packets
        # correctly. The bug has been fixed in OVS v2.1.0.
        match = parser.OFPMatch()
        actions = [parser.OFPActionOutput(ofproto.OFPP_CONTROLLER,
                                          ofproto.OFPCML_NO_BUFFER)]
        self.add_flow(datapath, 0, match, actions)

    def add_flow(self, datapath, priority, match, actions, buffer_id=None,
                 idle_timeout=0, hard_timeout=0):
        """Send a FlowMod applying ``actions`` to packets matching ``match``.

        An empty ``actions`` list installs a drop entry. When ``buffer_id`` is
        given, it is passed along so the switch handles that buffered packet
        with the new entry (no separate PacketOut is needed).
        """
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser

        inst = [parser.OFPInstructionActions(ofproto.OFPIT_APPLY_ACTIONS,
                                             actions)]
        fields = dict(datapath=datapath, priority=priority, match=match,
                      instructions=inst, idle_timeout=idle_timeout,
                      hard_timeout=hard_timeout)
        # `is not None`, not truthiness: 0 is a valid buffer id.
        if buffer_id is not None:
            fields['buffer_id'] = buffer_id
        datapath.send_msg(parser.OFPFlowMod(**fields))

    @set_ev_cls(ofp_event.EventOFPPacketIn, MAIN_DISPATCHER)
    def _packet_in_handler(self, ev):
        """Learn the sender's port, ask the firewall, then forward or drop."""
        # If you hit this you might want to increase
        # the "miss_send_length" of your switch
        if ev.msg.msg_len < ev.msg.total_len:
            self.logger.debug("packet truncated: only %s of %s bytes",
                              ev.msg.msg_len, ev.msg.total_len)
        msg = ev.msg
        datapath = msg.datapath
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser
        in_port = msg.match['in_port']

        pkt = packet.Packet(msg.data)
        eth = pkt.get_protocols(ethernet.ethernet)[0]

        if eth.ethertype == ether_types.ETH_TYPE_LLDP:
            # ignore lldp packet
            return
        dst = eth.dst
        src = eth.src

        dpid = datapath.id
        self.mac_to_port.setdefault(dpid, {})

        self.logger.info("packet in %s %s %s %s", dpid, src, dst, in_port)

        # learn a mac address to avoid FLOOD next time.
        self.mac_to_port[dpid][src] = in_port

        if dst in self.mac_to_port[dpid]:
            out_port = self.mac_to_port[dpid][dst]
        else:
            out_port = ofproto.OFPP_FLOOD

        actions = [parser.OFPActionOutput(out_port)]

        # run firewall on it
        firewall = self.firewalls.get(dpid)
        if firewall is None:
            firewall = self.firewalls[dpid] = Firewall(datapath)
        verdict = tuple(firewall.actions_for_new_flow(
            event=ev,
            message=msg,
            datapath=datapath,
            ofproto=ofproto,
            eth=eth,
            parser=parser,
            allow_action=actions
        ))
        # The verdict is (actions,) or (actions, save_in_flow_table); a
        # missing flag means the verdict may be cached.
        actions = verdict[0]
        save_in_flow_table = verdict[1] if len(verdict) > 1 else True

        # install a flow to avoid packet_in next time
        if save_in_flow_table and out_port != ofproto.OFPP_FLOOD:
            # eth_src is part of the match because the firewall's verdict
            # depends on the sender (blacklist / vLAN membership); without it
            # one host's verdict would apply to every host behind in_port.
            match = parser.OFPMatch(in_port=in_port, eth_dst=dst, eth_src=src)
            # verify if we have a valid buffer_id, if yes avoid to send both
            # flow_mod & packet_out
            if msg.buffer_id != ofproto.OFP_NO_BUFFER:
                self.add_flow(datapath, 1, match, actions, msg.buffer_id,
                              idle_timeout=30, hard_timeout=90)
                return
            self.add_flow(datapath, 1, match, actions,
                          idle_timeout=30, hard_timeout=90)

        data = None
        if msg.buffer_id == ofproto.OFP_NO_BUFFER:
            data = msg.data

        out = parser.OFPPacketOut(datapath=datapath, buffer_id=msg.buffer_id,
                                  in_port=in_port, actions=actions, data=data)
        datapath.send_msg(out)
|
|
|
|
|
|
if __name__ == '__main__':
    # Standalone use: edit the rules file through the Tk GUI. No switch is
    # attached, hence the None datapath.
    gui_firewall = Firewall(None)
    gui_firewall.manage_from_gui()
|