Complete first alpha version

This commit is contained in:
Daniel
2018-12-12 19:18:23 +01:00
parent 8c11a35590
commit f35872ec51
36 changed files with 624 additions and 293 deletions

6
README.md Normal file
View File

@@ -0,0 +1,6 @@
# Portmaster
## Dependencies
- Arch/based: libnetfilter_queue
x

View File

@@ -38,7 +38,7 @@ var (
) )
func init() { func init() {
modules.Register("firewall", prep, start, stop, "global", "network", "nameserver") modules.Register("firewall", prep, start, stop, "global", "network", "nameserver", "profile")
} }
func prep() (err error) { func prep() (err error) {
@@ -112,12 +112,6 @@ func handlePacket(pkt packet.Packet) {
return return
} }
// allow anything that goes to a tunnel entrypoint
if pkt.IsOutbound() && (pkt.GetIPHeader().Dst.Equal(tunnelEntry4) || pkt.GetIPHeader().Dst.Equal(tunnelEntry6)) {
pkt.PermanentAccept()
return
}
// log.Debugf("firewall: pkt %s has ID %s", pkt, pkt.GetConnectionID()) // log.Debugf("firewall: pkt %s has ID %s", pkt, pkt.GetConnectionID())
// use this to time how long it takes process packet // use this to time how long it takes process packet
@@ -147,7 +141,7 @@ func handlePacket(pkt packet.Packet) {
link.HandlePacket(pkt) link.HandlePacket(pkt)
return return
} }
verdict(pkt, link.Verdict) verdict(pkt, link.GetVerdict())
} }
@@ -156,42 +150,45 @@ func initialHandler(pkt packet.Packet, link *network.Link) {
// get Connection // get Connection
connection, err := network.GetConnectionByFirstPacket(pkt) connection, err := network.GetConnectionByFirstPacket(pkt)
if err != nil { if err != nil {
link.Lock()
if err != process.ErrConnectionNotFound { if err != process.ErrConnectionNotFound {
log.Warningf("firewall: could not find process of packet (dropping link %s): %s", pkt.String(), err) log.Warningf("firewall: could not find process of packet (dropping link %s): %s", pkt.String(), err)
link.AddReason(fmt.Sprintf("could not find process or it does not exist (unsolicited packet): %s", err)) link.Deny(fmt.Sprintf("could not find process or it does not exist (unsolicited packet): %s", err))
} else { } else {
log.Warningf("firewall: internal error finding process of packet (dropping link %s): %s", pkt.String(), err) log.Warningf("firewall: internal error finding process of packet (dropping link %s): %s", pkt.String(), err)
link.AddReason(fmt.Sprintf("internal error finding process: %s", err)) link.Deny(fmt.Sprintf("internal error finding process: %s", err))
} }
link.Unlock()
if pkt.IsInbound() { if pkt.IsInbound() {
network.UnknownIncomingConnection.AddLink(link) network.UnknownIncomingConnection.AddLink(link)
} else { } else {
network.UnknownDirectConnection.AddLink(link) network.UnknownDirectConnection.AddLink(link)
} }
verdict(pkt, link.Verdict)
return
}
// reroute dns requests to nameserver verdict(pkt, link.GetVerdict())
if connection.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.GetTCPUDPHeader() != nil && !pkt.GetIPHeader().Dst.Equal(localhost) && pkt.GetTCPUDPHeader().DstPort == 53 { link.StopFirewallHandler()
pkt.RerouteToNameserver()
return return
} }
// add new Link to Connection (and save both) // add new Link to Connection (and save both)
connection.AddLink(link) connection.AddLink(link)
// reroute dns requests to nameserver
if connection.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.GetTCPUDPHeader() != nil && !pkt.GetIPHeader().Dst.Equal(localhost) && pkt.GetTCPUDPHeader().DstPort == 53 {
link.RerouteToNameserver()
verdict(pkt, link.GetVerdict())
link.StopFirewallHandler()
return
}
// make a decision if not made already // make a decision if not made already
if connection.Verdict == network.UNDECIDED { if connection.GetVerdict() == network.UNDECIDED {
DecideOnConnection(connection, pkt) DecideOnConnection(connection, pkt)
} }
if connection.Verdict != network.CANTSAY { if connection.GetVerdict() == network.ACCEPT {
link.UpdateVerdict(connection.Verdict)
} else {
DecideOnLink(connection, link, pkt) DecideOnLink(connection, link, pkt)
} else {
link.UpdateVerdict(connection.GetVerdict())
} }
// log decision // log decision
@@ -205,7 +202,7 @@ func initialHandler(pkt packet.Packet, link *network.Link) {
// // tunnel link, but also inspect (after reroute) // // tunnel link, but also inspect (after reroute)
// link.Tunneled = true // link.Tunneled = true
// link.SetFirewallHandler(inspectThenVerdict) // link.SetFirewallHandler(inspectThenVerdict)
// verdict(pkt, link.Verdict) // verdict(pkt, link.GetVerdict())
// case port17Active: // case port17Active:
// // tunnel link, don't inspect // // tunnel link, don't inspect
// link.Tunneled = true // link.Tunneled = true
@@ -216,7 +213,7 @@ func initialHandler(pkt packet.Packet, link *network.Link) {
inspectThenVerdict(pkt, link) inspectThenVerdict(pkt, link)
default: default:
link.StopFirewallHandler() link.StopFirewallHandler()
verdict(pkt, link.Verdict) verdict(pkt, link.GetVerdict())
} }
} }
@@ -225,10 +222,11 @@ func inspectThenVerdict(pkt packet.Packet, link *network.Link) {
pktVerdict, continueInspection := inspection.RunInspectors(pkt, link) pktVerdict, continueInspection := inspection.RunInspectors(pkt, link)
if continueInspection { if continueInspection {
// do not allow to circumvent link decision: e.g. to ACCEPT packets from a DROP-ed link // do not allow to circumvent link decision: e.g. to ACCEPT packets from a DROP-ed link
if pktVerdict > link.Verdict { linkVerdict := link.GetVerdict()
if pktVerdict > linkVerdict {
verdict(pkt, pktVerdict) verdict(pkt, pktVerdict)
} else { } else {
verdict(pkt, link.Verdict) verdict(pkt, linkVerdict)
} }
return return
} }
@@ -236,9 +234,11 @@ func inspectThenVerdict(pkt packet.Packet, link *network.Link) {
// we are done with inspecting // we are done with inspecting
link.StopFirewallHandler() link.StopFirewallHandler()
link.Lock()
defer link.Unlock()
link.VerdictPermanent = permanentVerdicts() link.VerdictPermanent = permanentVerdicts()
if link.VerdictPermanent { if link.VerdictPermanent {
link.Save() go link.Save()
permanentVerdict(pkt, link.Verdict) permanentVerdict(pkt, link.Verdict)
} else { } else {
verdict(pkt, link.Verdict) verdict(pkt, link.Verdict)
@@ -259,6 +259,12 @@ func permanentVerdict(pkt packet.Packet, action network.Verdict) {
atomic.AddUint64(packetsDropped, 1) atomic.AddUint64(packetsDropped, 1)
pkt.PermanentDrop() pkt.PermanentDrop()
return return
case network.RerouteToNameserver:
pkt.RerouteToNameserver()
return
case network.RerouteToTunnel:
pkt.RerouteToTunnel()
return
} }
pkt.Drop() pkt.Drop()
} }
@@ -277,6 +283,12 @@ func verdict(pkt packet.Packet, action network.Verdict) {
atomic.AddUint64(packetsDropped, 1) atomic.AddUint64(packetsDropped, 1)
pkt.Drop() pkt.Drop()
return return
case network.RerouteToNameserver:
pkt.RerouteToNameserver()
return
case network.RerouteToTunnel:
pkt.RerouteToTunnel()
return
} }
pkt.Drop() pkt.Drop()
} }
@@ -295,18 +307,22 @@ func verdict(pkt packet.Packet, action network.Verdict) {
// } // }
func logInitialVerdict(link *network.Link) { func logInitialVerdict(link *network.Link) {
// switch link.Verdict { // switch link.GetVerdict() {
// case network.ACCEPT: // case network.ACCEPT:
// log.Infof("firewall: accepting new link: %s", link.String()) // log.Infof("firewall: accepting new link: %s", link.String())
// case network.BLOCK: // case network.BLOCK:
// log.Infof("firewall: blocking new link: %s", link.String()) // log.Infof("firewall: blocking new link: %s", link.String())
// case network.DROP: // case network.DROP:
// log.Infof("firewall: dropping new link: %s", link.String()) // log.Infof("firewall: dropping new link: %s", link.String())
// case network.RerouteToNameserver:
// log.Infof("firewall: rerouting new link to nameserver: %s", link.String())
// case network.RerouteToTunnel:
// log.Infof("firewall: rerouting new link to tunnel: %s", link.String())
// } // }
} }
func logChangedVerdict(link *network.Link) { func logChangedVerdict(link *network.Link) {
// switch link.Verdict { // switch link.GetVerdict() {
// case network.ACCEPT: // case network.ACCEPT:
// log.Infof("firewall: change! - now accepting link: %s", link.String()) // log.Infof("firewall: change! - now accepting link: %s", link.String())
// case network.BLOCK: // case network.BLOCK:

View File

@@ -8,6 +8,7 @@ import (
"github.com/Safing/portmaster/intel" "github.com/Safing/portmaster/intel"
"github.com/Safing/portmaster/network" "github.com/Safing/portmaster/network"
"github.com/Safing/portmaster/network/packet" "github.com/Safing/portmaster/network/packet"
"github.com/Safing/portmaster/profile"
"github.com/Safing/portmaster/status" "github.com/Safing/portmaster/status"
"github.com/agext/levenshtein" "github.com/agext/levenshtein"
@@ -59,10 +60,10 @@ func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string)
permitted, ok := profileSet.CheckDomain(fqdn) permitted, ok := profileSet.CheckDomain(fqdn)
if ok { if ok {
if permitted { if permitted {
log.Infof("firewall: accepting connection %s, domain is whitelisted", connection, domainElement, processElement) log.Infof("firewall: accepting connection %s, domain is whitelisted", connection)
connection.Accept("domain is whitelisted") connection.Accept("domain is whitelisted")
} else { } else {
log.Infof("firewall: denying connection %s, domain is blacklisted", connection, domainElement, processElement) log.Infof("firewall: denying connection %s, domain is blacklisted", connection)
connection.Deny("domain is blacklisted") connection.Deny("domain is blacklisted")
} }
return return
@@ -70,7 +71,7 @@ func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string)
switch profileSet.GetProfileMode() { switch profileSet.GetProfileMode() {
case profile.Whitelist: case profile.Whitelist:
log.Infof("firewall: denying connection %s, domain is not whitelisted", connection, domainElement, processElement) log.Infof("firewall: denying connection %s, domain is not whitelisted", connection)
connection.Deny("domain is not whitelisted") connection.Deny("domain is not whitelisted")
case profile.Prompt: case profile.Prompt:
@@ -97,9 +98,9 @@ func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string)
break matchLoop break matchLoop
} }
} }
if levenshtein.Match(domainElement, profile.Name, nil) > 0.5 { if levenshtein.Match(domainElement, profileSet.UserProfile().Name, nil) > 0.5 {
matched = true matched = true
processElement = profile.Name processElement = profileSet.UserProfile().Name
break matchLoop break matchLoop
} }
if levenshtein.Match(domainElement, connection.Process().Name, nil) > 0.5 { if levenshtein.Match(domainElement, connection.Process().Name, nil) > 0.5 {
@@ -107,6 +108,11 @@ func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string)
processElement = connection.Process().Name processElement = connection.Process().Name
break matchLoop break matchLoop
} }
if levenshtein.Match(domainElement, connection.Process().ExecName, nil) > 0.5 {
matched = true
processElement = connection.Process().ExecName
break matchLoop
}
} }
if matched { if matched {
@@ -115,15 +121,15 @@ func DecideOnConnectionBeforeIntel(connection *network.Connection, fqdn string)
} }
} }
if connection.Verdict != network.ACCEPT { if connection.GetVerdict() != network.ACCEPT {
// TODO // TODO
log.Infof("firewall: accepting connection %s, domain permitted (prompting is not yet implemented)", connection, domainElement, processElement) log.Infof("firewall: accepting connection %s, domain permitted (prompting is not yet implemented)", connection)
connection.Accept("domain permitted (prompting is not yet implemented)") connection.Accept("domain permitted (prompting is not yet implemented)")
} }
case profile.Blacklist: case profile.Blacklist:
log.Infof("firewall: denying connection %s, domain is not blacklisted", connection, domainElement, processElement) log.Infof("firewall: accepting connection %s, domain is not blacklisted", connection)
connection.Deny("domain is not blacklisted") connection.Accept("domain is not blacklisted")
} }
} }
@@ -175,8 +181,8 @@ func DecideOnConnection(connection *network.Connection, pkt packet.Packet) {
} }
// check if there is a profile // check if there is a profile
profileSet := connection.Process().ProfileSet profileSet := connection.Process().ProfileSet()
if profile == nil { if profileSet == nil {
log.Errorf("firewall: denying connection %s, no profile set", connection) log.Errorf("firewall: denying connection %s, no profile set", connection)
connection.Deny("no profile") connection.Deny("no profile")
return return
@@ -185,17 +191,17 @@ func DecideOnConnection(connection *network.Connection, pkt packet.Packet) {
// check connection type // check connection type
switch connection.Domain { switch connection.Domain {
case IncomingHost, IncomingLAN, IncomingInternet, IncomingInvalid: case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid:
if !profileSet.CheckFlag(profile.Service) { if !profileSet.CheckFlag(profile.Service) {
log.Infof("firewall: denying connection %s, not a service", connection) log.Infof("firewall: denying connection %s, not a service", connection)
if connection.Domain == IncomingHost { if connection.Domain == network.IncomingHost {
connection.Block("not a service") connection.Block("not a service")
} else { } else {
connection.Drop("not a service") connection.Drop("not a service")
} }
return return
} }
case PeerLAN, PeerInternet, PeerInvalid: // Important: PeerHost is and should be missing! case network.PeerLAN, network.PeerInternet, network.PeerInvalid: // Important: PeerHost is and should be missing!
if !profileSet.CheckFlag(profile.PeerToPeer) { if !profileSet.CheckFlag(profile.PeerToPeer) {
log.Infof("firewall: denying connection %s, peer to peer connections (to an IP) not allowed", connection) log.Infof("firewall: denying connection %s, peer to peer connections (to an IP) not allowed", connection)
connection.Deny("peer to peer connections (to an IP) not allowed") connection.Deny("peer to peer connections (to an IP) not allowed")
@@ -205,54 +211,54 @@ func DecideOnConnection(connection *network.Connection, pkt packet.Packet) {
// check network scope // check network scope
switch connection.Domain { switch connection.Domain {
case IncomingHost: case network.IncomingHost:
if !profileSet.CheckFlag(profile.Localhost) { if !profileSet.CheckFlag(profile.Localhost) {
log.Infof("firewall: denying connection %s, serving localhost not allowed", connection) log.Infof("firewall: denying connection %s, serving localhost not allowed", connection)
connection.Block("serving localhost not allowed") connection.Block("serving localhost not allowed")
return return
} }
case IncomingLAN: case network.IncomingLAN:
if !profileSet.CheckFlag(profile.LAN) { if !profileSet.CheckFlag(profile.LAN) {
log.Infof("firewall: denying connection %s, serving LAN not allowed", connection) log.Infof("firewall: denying connection %s, serving LAN not allowed", connection)
connection.Deny("serving LAN not allowed") connection.Deny("serving LAN not allowed")
return return
} }
case IncomingInternet: case network.IncomingInternet:
if !profileSet.CheckFlag(profile.Internet) { if !profileSet.CheckFlag(profile.Internet) {
log.Infof("firewall: denying connection %s, serving Internet not allowed", connection) log.Infof("firewall: denying connection %s, serving Internet not allowed", connection)
connection.Deny("serving Internet not allowed") connection.Deny("serving Internet not allowed")
return return
} }
case IncomingInvalid: case network.IncomingInvalid:
log.Infof("firewall: denying connection %s, invalid IP address", connection) log.Infof("firewall: denying connection %s, invalid IP address", connection)
connection.Drop("invalid IP address") connection.Drop("invalid IP address")
return return
case PeerHost: case network.PeerHost:
if !profileSet.CheckFlag(profile.Localhost) { if !profileSet.CheckFlag(profile.Localhost) {
log.Infof("firewall: denying connection %s, accessing localhost not allowed", connection) log.Infof("firewall: denying connection %s, accessing localhost not allowed", connection)
connection.Block("accessing localhost not allowed") connection.Block("accessing localhost not allowed")
return return
} }
case PeerLAN: case network.PeerLAN:
if !profileSet.CheckFlag(profile.LAN) { if !profileSet.CheckFlag(profile.LAN) {
log.Infof("firewall: denying connection %s, accessing the LAN not allowed", connection) log.Infof("firewall: denying connection %s, accessing the LAN not allowed", connection)
connection.Deny("accessing the LAN not allowed") connection.Deny("accessing the LAN not allowed")
return return
} }
case PeerInternet: case network.PeerInternet:
if !profileSet.CheckFlag(profile.Internet) { if !profileSet.CheckFlag(profile.Internet) {
log.Infof("firewall: denying connection %s, accessing the Internet not allowed", connection) log.Infof("firewall: denying connection %s, accessing the Internet not allowed", connection)
connection.Deny("accessing the Internet not allowed") connection.Deny("accessing the Internet not allowed")
return return
} }
case PeerInvalid: case network.PeerInvalid:
log.Infof("firewall: denying connection %s, invalid IP address", connection) log.Infof("firewall: denying connection %s, invalid IP address", connection)
connection.Deny("invalid IP address") connection.Deny("invalid IP address")
return return
} }
log.Infof("firewall: accepting connection %s", connection) log.Infof("firewall: accepting connection %s", connection)
connection.Accept() connection.Accept("")
} }
// DecideOnLink makes a decision about a link with the first packet. // DecideOnLink makes a decision about a link with the first packet.
@@ -264,8 +270,8 @@ func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet
// Profile.ListenPorts // Profile.ListenPorts
// check if there is a profile // check if there is a profile
profileSet := connection.Process().ProfileSet profileSet := connection.Process().ProfileSet()
if profile == nil { if profileSet == nil {
log.Infof("firewall: no profile, denying %s", link) log.Infof("firewall: no profile, denying %s", link)
link.Block("no profile") link.Block("no profile")
return return
@@ -274,20 +280,20 @@ func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet
// get remote Port // get remote Port
protocol := pkt.GetIPHeader().Protocol protocol := pkt.GetIPHeader().Protocol
var remotePort uint16 var dstPort uint16
tcpUdpHeader := pkt.GetTCPUDPHeader() tcpUDPHeader := pkt.GetTCPUDPHeader()
if tcpUdpHeader != nil { if tcpUDPHeader != nil {
remotePort = tcpUdpHeader.DstPort dstPort = tcpUDPHeader.DstPort
} }
// check port list // check port list
permitted, ok := profileSet.CheckPort(connection.Direction, protocol, remotePort) permitted, ok := profileSet.CheckPort(connection.Direction, uint8(protocol), dstPort)
if ok { if ok {
if permitted { if permitted {
log.Infof("firewall: accepting link %s", link) log.Infof("firewall: accepting link %s", link)
link.Accept("port whitelisted") link.Accept("port whitelisted")
} else { } else {
log.Infof("firewall: denying link %s: port %d is blacklisted", link, remotePort) log.Infof("firewall: denying link %s: port %d is blacklisted", link, dstPort)
link.Deny("port blacklisted") link.Deny("port blacklisted")
} }
return return
@@ -295,14 +301,17 @@ func DecideOnLink(connection *network.Connection, link *network.Link, pkt packet
switch profileSet.GetProfileMode() { switch profileSet.GetProfileMode() {
case profile.Whitelist: case profile.Whitelist:
log.Infof("firewall: denying link %s: port %d is not whitelisted", link, remotePort) log.Infof("firewall: denying link %s: port %d is not whitelisted", link, dstPort)
link.Deny("port is not whitelisted") link.Deny("port is not whitelisted")
return
case profile.Prompt: case profile.Prompt:
log.Infof("firewall: denying link %s: port %d is blacklisted", link, remotePort) log.Infof("firewall: accepting link %s: port %d is blacklisted", link, dstPort)
link.Accept("port permitted (prompting is not yet implemented)") link.Accept("port permitted (prompting is not yet implemented)")
return
case profile.Blacklist: case profile.Blacklist:
log.Infof("firewall: denying link %s: port %d is blacklisted", link, remotePort) log.Infof("firewall: accepting link %s: port %d is not blacklisted", link, dstPort)
link.Deny("port is not blacklisted") link.Accept("port is not blacklisted")
return
} }
log.Infof("firewall: accepting link %s", link) log.Infof("firewall: accepting link %s", link)

View File

@@ -7,10 +7,11 @@ import (
// module dependencies // module dependencies
_ "github.com/Safing/portbase/database/dbmodule" _ "github.com/Safing/portbase/database/dbmodule"
_ "github.com/Safing/portbase/database/storage/badger" _ "github.com/Safing/portbase/database/storage/badger"
_ "github.com/Safing/portmaster/status"
) )
func init() { func init() {
modules.Register("global", nil, start, nil, "database") modules.Register("global", nil, start, nil, "database", "status")
} }
func start() error { func start() error {

View File

@@ -59,10 +59,10 @@ func GetIPInfo(ip string) (*IPInfo, error) {
// AddDomain adds a domain to the list and reports back if it was added, or was already present. // AddDomain adds a domain to the list and reports back if it was added, or was already present.
func (ipi *IPInfo) AddDomain(domain string) (added bool) { func (ipi *IPInfo) AddDomain(domain string) (added bool) {
ipi.Lock()
defer ipi.Unlock()
if !utils.StringInSlice(ipi.Domains, domain) { if !utils.StringInSlice(ipi.Domains, domain) {
newDomains := make([]string, 1, len(ipi.Domains)+1) ipi.Domains = append([]string{domain}, ipi.Domains...)
newDomains[0] = domain
ipi.Domains = append(newDomains, ipi.Domains...)
return true return true
} }
return false return false
@@ -70,11 +70,22 @@ func (ipi *IPInfo) AddDomain(domain string) (added bool) {
// Save saves the IPInfo record to the database. // Save saves the IPInfo record to the database.
func (ipi *IPInfo) Save() error { func (ipi *IPInfo) Save() error {
ipi.SetKey(makeIPInfoKey(ipi.IP)) ipi.Lock()
return ipInfoDatabase.PutNew(ipi) if !ipi.KeyIsSet() {
ipi.SetKey(makeIPInfoKey(ipi.IP))
}
ipi.Unlock()
return ipInfoDatabase.Put(ipi)
} }
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or " // FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
func (ipi *IPInfo) FmtDomains() string { func (ipi *IPInfo) FmtDomains() string {
return strings.Join(ipi.Domains, " or ") return strings.Join(ipi.Domains, " or ")
} }
// FmtDomains returns a string consisting of the domains that have seen to use this IP, joined by " or "
func (ipi *IPInfo) String() string {
ipi.Lock()
defer ipi.Unlock()
return fmt.Sprintf("<IPInfo[%s] %s: %s", ipi.Key(), ipi.IP, ipi.FmtDomains())
}

25
intel/ipinfo_test.go Normal file
View File

@@ -0,0 +1,25 @@
package intel
import "testing"
func testDomains(t *testing.T, ipi *IPInfo, expectedDomains string) {
if ipi.FmtDomains() != expectedDomains {
t.Errorf("unexpected domains '%s', expected '%s'", ipi.FmtDomains(), expectedDomains)
}
}
func TestIPInfo(t *testing.T) {
ipi := &IPInfo{
IP: "1.2.3.4",
Domains: []string{"example.com.", "sub.example.com."},
}
testDomains(t, ipi, "example.com. or sub.example.com.")
ipi.AddDomain("added.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi.AddDomain("sub.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
ipi.AddDomain("added.example.com.")
testDomains(t, ipi, "added.example.com. or example.com. or sub.example.com.")
}

View File

@@ -152,7 +152,7 @@ configuredServersLoop:
ServerType: parts[0], ServerType: parts[0],
ServerAddress: parts[1], ServerAddress: parts[1],
ServerIP: ip, ServerIP: ip,
ServerIPScope: netutils.ClassifyAddress(ip), ServerIPScope: netutils.ClassifyIP(ip),
ServerPort: port, ServerPort: port,
LastFail: &lastFail, LastFail: &lastFail,
Source: "config", Source: "config",
@@ -207,7 +207,7 @@ assignedServersLoop:
ServerType: "dns", ServerType: "dns",
ServerAddress: urlFormatAddress(nameserver.IP, 53), ServerAddress: urlFormatAddress(nameserver.IP, 53),
ServerIP: nameserver.IP, ServerIP: nameserver.IP,
ServerIPScope: netutils.ClassifyAddress(nameserver.IP), ServerIPScope: netutils.ClassifyIP(nameserver.IP),
ServerPort: 53, ServerPort: 53,
LastFail: &lastFail, LastFail: &lastFail,
Source: "dhcp", Source: "dhcp",

View File

@@ -216,9 +216,9 @@ entryLoop:
classification = -1 classification = -1
switch v := rr.(type) { switch v := rr.(type) {
case *dns.A: case *dns.A:
classification = netutils.ClassifyAddress(v.A) classification = netutils.ClassifyIP(v.A)
case *dns.AAAA: case *dns.AAAA:
classification = netutils.ClassifyAddress(v.AAAA) classification = netutils.ClassifyIP(v.AAAA)
} }
if classification >= 0 { if classification >= 0 {

32
main.go
View File

@@ -1,10 +1,13 @@
package main package main
import ( import (
"flag"
"fmt" "fmt"
"os" "os"
"os/signal" "os/signal"
"runtime/pprof"
"syscall" "syscall"
"time"
"github.com/Safing/portbase/info" "github.com/Safing/portbase/info"
"github.com/Safing/portbase/log" "github.com/Safing/portbase/log"
@@ -12,11 +15,21 @@ import (
// include packages here // include packages here
_ "github.com/Safing/portbase/api"
_ "github.com/Safing/portbase/database/dbmodule" _ "github.com/Safing/portbase/database/dbmodule"
_ "github.com/Safing/portbase/database/storage/badger" _ "github.com/Safing/portbase/database/storage/badger"
_ "github.com/Safing/portmaster/firewall" _ "github.com/Safing/portmaster/firewall"
_ "github.com/Safing/portmaster/nameserver"
) )
var (
printStackOnExit bool
)
func init() {
flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down")
}
func main() { func main() {
// Set Info // Set Info
@@ -51,7 +64,24 @@ func main() {
case <-signalCh: case <-signalCh:
fmt.Println(" <INTERRUPT>") fmt.Println(" <INTERRUPT>")
log.Warning("main: program was interrupted, shutting down.") log.Warning("main: program was interrupted, shutting down.")
modules.Shutdown()
if printStackOnExit {
fmt.Println("=== PRINTING STACK ===")
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
fmt.Println("=== END STACK ===")
}
go func() {
modules.Shutdown()
os.Exit(0)
}()
go func() {
time.Sleep(3 * time.Second)
fmt.Println("===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====")
pprof.Lookup("goroutine").WriteTo(os.Stdout, 2)
os.Exit(1)
}()
case <-modules.ShuttingDown(): case <-modules.ShuttingDown():
} }

View File

@@ -12,10 +12,10 @@ import (
"github.com/Safing/portbase/modules" "github.com/Safing/portbase/modules"
"github.com/Safing/portmaster/analytics/algs" "github.com/Safing/portmaster/analytics/algs"
"github.com/Safing/portmaster/firewall"
"github.com/Safing/portmaster/intel" "github.com/Safing/portmaster/intel"
"github.com/Safing/portmaster/network" "github.com/Safing/portmaster/network"
"github.com/Safing/portmaster/network/netutils" "github.com/Safing/portmaster/network/netutils"
"github.com/Safing/portmaster/firewall"
) )
var ( var (
@@ -138,19 +138,19 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
// log.Tracef("nameserver: took %s to get connection/process of %s request", time.Now().Sub(timed).String(), fqdn) // log.Tracef("nameserver: took %s to get connection/process of %s request", time.Now().Sub(timed).String(), fqdn)
// check profile before we even get intel and rr // check profile before we even get intel and rr
if connection.Verdict == network.UNDECIDED { if connection.GetVerdict() == network.UNDECIDED {
// start = time.Now() // start = time.Now()
firewall.DecideOnConnectionBeforeIntel(connection, fqdn) firewall.DecideOnConnectionBeforeIntel(connection, fqdn)
// log.Tracef("nameserver: took %s to make decision", time.Since(start)) // log.Tracef("nameserver: took %s to make decision", time.Since(start))
} }
if connection.Verdict == network.BLOCK || connection.Verdict == network.DROP { if connection.GetVerdict() == network.BLOCK || connection.GetVerdict() == network.DROP {
nxDomain(w, query) nxDomain(w, query)
return return
} }
// get intel and RRs // get intel and RRs
// start = time.Now() // start = time.Now()
domainIntel, rrCache := intel.GetIntelAndRRs(fqdn, qtype, connection.Process().Profile.SecurityLevel) domainIntel, rrCache := intel.GetIntelAndRRs(fqdn, qtype, connection.Process().ProfileSet().SecurityLevel())
// log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start)) // log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start))
if rrCache == nil { if rrCache == nil {
// TODO: analyze nxdomain requests, malware could be trying DGA-domains // TODO: analyze nxdomain requests, malware could be trying DGA-domains
@@ -160,14 +160,16 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
} }
// set intel // set intel
connection.Lock()
connection.Intel = domainIntel connection.Intel = domainIntel
connection.Unlock()
connection.Save() connection.Save()
// do a full check with intel // do a full check with intel
if connection.Verdict == network.UNDECIDED { if connection.GetVerdict() == network.UNDECIDED {
rrCache = firewall.DecideOnConnectionAfterIntel(connection, fqdn, rrCache) rrCache = firewall.DecideOnConnectionAfterIntel(connection, fqdn, rrCache)
} }
if rrCache == nil || connection.Verdict == network.BLOCK || connection.Verdict == network.DROP { if rrCache == nil || connection.GetVerdict() == network.BLOCK || connection.GetVerdict() == network.DROP {
nxDomain(w, query) nxDomain(w, query)
return return
} }
@@ -179,24 +181,27 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
ipInfo, err := intel.GetIPInfo(v.A.String()) ipInfo, err := intel.GetIPInfo(v.A.String())
if err != nil { if err != nil {
ipInfo = &intel.IPInfo{ ipInfo = &intel.IPInfo{
IP: v.A.String(),
Domains: []string{fqdn}, Domains: []string{fqdn},
} }
ipInfo.Create(v.A.String())
} else {
ipInfo.Domains = append(ipInfo.Domains, fqdn)
ipInfo.Save() ipInfo.Save()
} else {
if ipInfo.AddDomain(fqdn) {
ipInfo.Save()
}
} }
case *dns.AAAA: case *dns.AAAA:
ipInfo, err := intel.GetIPInfo(v.AAAA.String()) ipInfo, err := intel.GetIPInfo(v.AAAA.String())
if err != nil { if err != nil {
ipInfo = &intel.IPInfo{ ipInfo = &intel.IPInfo{
IP: IP: v.AAAA.String(),
Domains: []string{fqdn}, Domains: []string{fqdn},
} }
ipInfo.Create(v.AAAA.String())
} else {
ipInfo.Domains = append(ipInfo.Domains, fqdn)
ipInfo.Save() ipInfo.Save()
} else {
if ipInfo.AddDomain(fqdn) {
ipInfo.Save()
}
} }
} }
} }

View File

@@ -17,8 +17,11 @@ var (
func cleaner() { func cleaner() {
for { for {
time.Sleep(cleanerTickDuration) time.Sleep(cleanerTickDuration)
cleanLinks() cleanLinks()
time.Sleep(10 * time.Second)
cleanConnections() cleanConnections()
time.Sleep(10 * time.Second)
cleanProcesses() cleanProcesses()
} }
} }
@@ -26,18 +29,21 @@ func cleaner() {
func cleanLinks() { func cleanLinks() {
activeIDs := process.GetActiveConnectionIDs() activeIDs := process.GetActiveConnectionIDs()
dataLock.Lock()
defer dataLock.Lock()
now := time.Now().Unix() now := time.Now().Unix()
deleteOlderThan := time.Now().Add(-deadLinksTimeout).Unix() deleteOlderThan := time.Now().Add(-deadLinksTimeout).Unix()
linksLock.RLock()
defer linksLock.RUnlock()
var found bool var found bool
for key, link := range links { for key, link := range links {
// delete dead links // delete dead links
if link.Ended > 0 && link.Ended < deleteOlderThan { link.Lock()
link.Delete() deleteThis := link.Ended > 0 && link.Ended < deleteOlderThan
link.Unlock()
if deleteThis {
go link.Delete()
continue continue
} }
@@ -53,21 +59,23 @@ func cleanLinks() {
// mark end time // mark end time
if !found { if !found {
link.Ended = now link.Ended = now
link.Save() go link.Save()
} }
} }
} }
func cleanConnections() { func cleanConnections() {
dataLock.Lock() connectionsLock.RLock()
defer dataLock.Lock() defer connectionsLock.RUnlock()
threshold := time.Now().Add(-thresholdDuration).Unix() threshold := time.Now().Add(-thresholdDuration).Unix()
for _, conn := range connections { for _, conn := range connections {
conn.Lock()
if conn.FirstLinkEstablished < threshold && conn.LinkCount == 0 { if conn.FirstLinkEstablished < threshold && conn.LinkCount == 0 {
conn.Delete() go conn.Delete()
} }
conn.Unlock()
} }
} }

View File

@@ -11,6 +11,7 @@ import (
"github.com/Safing/portbase/database/record" "github.com/Safing/portbase/database/record"
"github.com/Safing/portmaster/intel" "github.com/Safing/portmaster/intel"
"github.com/Safing/portmaster/network/netutils"
"github.com/Safing/portmaster/network/packet" "github.com/Safing/portmaster/network/packet"
"github.com/Safing/portmaster/process" "github.com/Safing/portmaster/process"
) )
@@ -35,17 +36,28 @@ type Connection struct {
// Process returns the process that owns the connection. // Process returns the process that owns the connection.
func (conn *Connection) Process() *process.Process { func (conn *Connection) Process() *process.Process {
conn.Lock()
defer conn.Unlock()
return conn.process return conn.process
} }
// GetVerdict returns the current verdict.
func (conn *Connection) GetVerdict() Verdict {
conn.Lock()
defer conn.Unlock()
return conn.Verdict
}
// Accept accepts the connection and adds the given reason. // Accept accepts the connection and adds the given reason.
func (conn *Link) Accept(reason string) { func (conn *Connection) Accept(reason string) {
conn.AddReason(reason) conn.AddReason(reason)
conn.UpdateVerdict(ACCEPT) conn.UpdateVerdict(ACCEPT)
} }
// Deny blocks or drops the connection depending on the connection direction and adds the given reason. // Deny blocks or drops the connection depending on the connection direction and adds the given reason.
func (conn *Link) Deny(reason string) { func (conn *Connection) Deny(reason string) {
if conn.Direction { if conn.Direction {
conn.Drop(reason) conn.Drop(reason)
} else { } else {
@@ -54,13 +66,13 @@ func (conn *Link) Deny(reason string) {
} }
// Block blocks the connection and adds the given reason. // Block blocks the connection and adds the given reason.
func (conn *Link) Block(reason string) { func (conn *Connection) Block(reason string) {
conn.AddReason(reason) conn.AddReason(reason)
conn.UpdateVerdict(BLOCK) conn.UpdateVerdict(BLOCK)
} }
// Drop drops the connection and adds the given reason. // Drop drops the connection and adds the given reason.
func (conn *Link) Drop(reason string) { func (conn *Connection) Drop(reason string) {
conn.AddReason(reason) conn.AddReason(reason)
conn.UpdateVerdict(DROP) conn.UpdateVerdict(DROP)
} }
@@ -72,7 +84,7 @@ func (conn *Connection) UpdateVerdict(newVerdict Verdict) {
if newVerdict > conn.Verdict { if newVerdict > conn.Verdict {
conn.Verdict = newVerdict conn.Verdict = newVerdict
conn.Save() go conn.Save()
} }
} }
@@ -103,13 +115,13 @@ func GetConnectionByFirstPacket(pkt packet.Packet) (*Connection, error) {
// Incoming // Incoming
if direction { if direction {
switch netutils.ClassifyIP(pkt.GetIPHeader().Src) { switch netutils.ClassifyIP(pkt.GetIPHeader().Src) {
case HostLocal: case netutils.HostLocal:
domain = IncomingHost domain = IncomingHost
case LinkLocal, SiteLocal, LocalMulticast: case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
domain = IncomingLAN domain = IncomingLAN
case Global, GlobalMulticast: case netutils.Global, netutils.GlobalMulticast:
domain = IncomingInternet domain = IncomingInternet
case Invalid: case netutils.Invalid:
domain = IncomingInvalid domain = IncomingInvalid
} }
@@ -135,13 +147,13 @@ func GetConnectionByFirstPacket(pkt packet.Packet) (*Connection, error) {
// if no domain could be found, it must be a direct connection // if no domain could be found, it must be a direct connection
switch netutils.ClassifyIP(pkt.GetIPHeader().Dst) { switch netutils.ClassifyIP(pkt.GetIPHeader().Dst) {
case HostLocal: case netutils.HostLocal:
domain = PeerHost domain = PeerHost
case LinkLocal, SiteLocal, LocalMulticast: case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
domain = PeerLAN domain = PeerLAN
case Global, GlobalMulticast: case netutils.Global, netutils.GlobalMulticast:
domain = PeerInternet domain = PeerInternet
case Invalid: case netutils.Invalid:
domain = PeerInvalid domain = PeerInvalid
} }
@@ -205,8 +217,8 @@ func GetConnectionByDNSRequest(ip net.IP, port uint16, fqdn string) (*Connection
// GetConnection fetches a connection object from the internal storage. // GetConnection fetches a connection object from the internal storage.
func GetConnection(pid int, domain string) (conn *Connection, ok bool) { func GetConnection(pid int, domain string) (conn *Connection, ok bool) {
dataLock.RLock() connectionsLock.RLock()
defer dataLock.RUnlock() defer connectionsLock.RUnlock()
conn, ok = connections[fmt.Sprintf("%d/%s", pid, domain)] conn, ok = connections[fmt.Sprintf("%d/%s", pid, domain)]
return return
} }
@@ -217,58 +229,63 @@ func (conn *Connection) makeKey() string {
// Save saves the connection object in the storage and propagates the change. // Save saves the connection object in the storage and propagates the change.
func (conn *Connection) Save() error { func (conn *Connection) Save() error {
conn.Lock()
defer conn.Unlock()
if conn.process == nil { if conn.process == nil {
return errors.New("cannot save connection without process") return errors.New("cannot save connection without process")
} }
if conn.DatabaseKey() == "" { if !conn.KeyIsSet() {
conn.SetKey(fmt.Sprintf("network:tree/%d/%s", conn.process.Pid, conn.Domain)) conn.SetKey(fmt.Sprintf("network:tree/%d/%s", conn.process.Pid, conn.Domain))
conn.CreateMeta() conn.CreateMeta()
} }
key := conn.makeKey() key := conn.makeKey()
dataLock.RLock() connectionsLock.RLock()
_, ok := connections[key] _, ok := connections[key]
dataLock.RUnlock() connectionsLock.RUnlock()
if !ok { if !ok {
dataLock.Lock() connectionsLock.Lock()
connections[key] = conn connections[key] = conn
dataLock.Unlock() connectionsLock.Unlock()
} }
dbController.PushUpdate(conn) go dbController.PushUpdate(conn)
return nil return nil
} }
// Delete deletes a connection from the storage and propagates the change. // Delete deletes a connection from the storage and propagates the change.
func (conn *Connection) Delete() { func (conn *Connection) Delete() {
dataLock.Lock()
defer dataLock.Unlock()
delete(connections, conn.makeKey())
conn.Lock() conn.Lock()
defer conn.Lock() defer conn.Unlock()
connectionsLock.Lock()
delete(connections, conn.makeKey())
connectionsLock.Unlock()
conn.Meta().Delete() conn.Meta().Delete()
dbController.PushUpdate(conn) go dbController.PushUpdate(conn)
conn.process.RemoveConnection() conn.process.RemoveConnection()
} }
// AddLink applies the connection to the link and increases sets counter and timestamps. // AddLink applies the connection to the link and increases sets counter and timestamps.
func (conn *Connection) AddLink(link *Link) { func (conn *Connection) AddLink(link *Link) {
link.Lock() link.Lock()
defer link.Unlock()
link.connection = conn link.connection = conn
link.Verdict = conn.Verdict link.Verdict = conn.Verdict
link.Inspect = conn.Inspect link.Inspect = conn.Inspect
link.Unlock()
link.Save() link.Save()
conn.Lock() conn.Lock()
defer conn.Unlock()
conn.LinkCount++ conn.LinkCount++
conn.LastLinkEstablished = time.Now().Unix() conn.LastLinkEstablished = time.Now().Unix()
if conn.FirstLinkEstablished == 0 { if conn.FirstLinkEstablished == 0 {
conn.FirstLinkEstablished = conn.LastLinkEstablished conn.FirstLinkEstablished = conn.LastLinkEstablished
} }
conn.Unlock()
conn.Save() conn.Save()
} }
@@ -276,6 +293,7 @@ func (conn *Connection) AddLink(link *Link) {
func (conn *Connection) RemoveLink() { func (conn *Connection) RemoveLink() {
conn.Lock() conn.Lock()
defer conn.Unlock() defer conn.Unlock()
if conn.LinkCount > 0 { if conn.LinkCount > 0 {
conn.LinkCount-- conn.LinkCount--
} }
@@ -283,13 +301,16 @@ func (conn *Connection) RemoveLink() {
// String returns a string representation of Connection. // String returns a string representation of Connection.
func (conn *Connection) String() string { func (conn *Connection) String() string {
conn.Lock()
defer conn.Unlock()
switch conn.Domain { switch conn.Domain {
case "I": case IncomingHost, IncomingLAN, IncomingInternet, IncomingInvalid:
if conn.process == nil { if conn.process == nil {
return "? <- *" return "? <- *"
} }
return fmt.Sprintf("%s <- *", conn.process.String()) return fmt.Sprintf("%s <- *", conn.process.String())
case "D": case PeerHost, PeerLAN, PeerInternet, PeerInvalid:
if conn.process == nil { if conn.process == nil {
return "? -> *" return "? -> *"
} }

View File

@@ -14,9 +14,10 @@ import (
) )
var ( var (
links map[string]*Link links = make(map[string]*Link)
connections map[string]*Connection linksLock sync.RWMutex
dataLock sync.RWMutex connections = make(map[string]*Connection)
connectionsLock sync.RWMutex
dbController *database.Controller dbController *database.Controller
) )
@@ -29,9 +30,6 @@ type StorageInterface struct {
// Get returns a database record. // Get returns a database record.
func (s *StorageInterface) Get(key string) (record.Record, error) { func (s *StorageInterface) Get(key string) (record.Record, error) {
dataLock.RLock()
defer dataLock.RUnlock()
splitted := strings.Split(key, "/") splitted := strings.Split(key, "/")
switch splitted[0] { switch splitted[0] {
case "tree": case "tree":
@@ -45,11 +43,15 @@ func (s *StorageInterface) Get(key string) (record.Record, error) {
} }
} }
case 3: case 3:
connectionsLock.RLock()
defer connectionsLock.RUnlock()
conn, ok := connections[splitted[2]] conn, ok := connections[splitted[2]]
if ok { if ok {
return conn, nil return conn, nil
} }
case 4: case 4:
linksLock.RLock()
defer linksLock.RUnlock()
link, ok := links[splitted[3]] link, ok := links[splitted[3]]
if ok { if ok {
return link, nil return link, nil
@@ -77,22 +79,23 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
} }
} }
dataLock.RLock()
defer dataLock.RUnlock()
// connections // connections
connectionsLock.RLock()
for _, conn := range connections { for _, conn := range connections {
if strings.HasPrefix(conn.DatabaseKey(), q.DatabaseKeyPrefix()) { if strings.HasPrefix(conn.DatabaseKey(), q.DatabaseKeyPrefix()) {
it.Next <- conn it.Next <- conn
} }
} }
connectionsLock.RUnlock()
// links // links
linksLock.RLock()
for _, link := range links { for _, link := range links {
if strings.HasPrefix(link.DatabaseKey(), q.DatabaseKeyPrefix()) { if strings.HasPrefix(link.DatabaseKey(), q.DatabaseKeyPrefix()) {
it.Next <- link it.Next <- link
} }
} }
linksLock.RUnlock()
it.Finish(nil) it.Finish(nil)
} }

View File

@@ -24,6 +24,9 @@ func getNameserversFromDbus() ([]Nameserver, error) {
var nameservers []Nameserver var nameservers []Nameserver
var err error var err error
dbusConnLock.Lock()
defer dbusConnLock.Unlock()
if dbusConn == nil { if dbusConn == nil {
dbusConn, err = dbus.SystemBus() dbusConn, err = dbus.SystemBus()
} }
@@ -158,6 +161,9 @@ func getNameserversFromDbus() ([]Nameserver, error) {
func getConnectivityStateFromDbus() (uint8, error) { func getConnectivityStateFromDbus() (uint8, error) {
var err error var err error
dbusConnLock.Lock()
defer dbusConnLock.Unlock()
if dbusConn == nil { if dbusConn == nil {
dbusConn, err = dbus.SystemBus() dbusConn, err = dbus.SystemBus()
} }

View File

@@ -46,16 +46,33 @@ type Link struct {
// Connection returns the Connection the Link is part of // Connection returns the Connection the Link is part of
func (link *Link) Connection() *Connection { func (link *Link) Connection() *Connection {
link.Lock()
defer link.Unlock()
return link.connection return link.connection
} }
// GetVerdict returns the current verdict.
func (link *Link) GetVerdict() Verdict {
link.Lock()
defer link.Unlock()
return link.Verdict
}
// FirewallHandlerIsSet returns whether a firewall handler is set or not // FirewallHandlerIsSet returns whether a firewall handler is set or not
func (link *Link) FirewallHandlerIsSet() bool { func (link *Link) FirewallHandlerIsSet() bool {
link.Lock()
defer link.Unlock()
return link.firewallHandler != nil return link.firewallHandler != nil
} }
// SetFirewallHandler sets the firewall handler for this link // SetFirewallHandler sets the firewall handler for this link
func (link *Link) SetFirewallHandler(handler FirewallHandler) { func (link *Link) SetFirewallHandler(handler FirewallHandler) {
link.Lock()
defer link.Unlock()
if link.firewallHandler == nil { if link.firewallHandler == nil {
link.firewallHandler = handler link.firewallHandler = handler
link.pktQueue = make(chan packet.Packet, 1000) link.pktQueue = make(chan packet.Packet, 1000)
@@ -67,16 +84,22 @@ func (link *Link) SetFirewallHandler(handler FirewallHandler) {
// StopFirewallHandler unsets the firewall handler // StopFirewallHandler unsets the firewall handler
func (link *Link) StopFirewallHandler() { func (link *Link) StopFirewallHandler() {
link.Lock()
link.firewallHandler = nil
link.Unlock()
link.pktQueue <- nil link.pktQueue <- nil
} }
// HandlePacket queues packet of Link for handling // HandlePacket queues packet of Link for handling
func (link *Link) HandlePacket(pkt packet.Packet) { func (link *Link) HandlePacket(pkt packet.Packet) {
link.Lock()
defer link.Unlock()
if link.firewallHandler != nil { if link.firewallHandler != nil {
link.pktQueue <- pkt link.pktQueue <- pkt
return return
} }
log.Criticalf("network: link %s does not have a firewallHandler (maybe it's a copy), dropping packet", link) log.Criticalf("network: link %s does not have a firewallHandler, dropping packet", link)
pkt.Drop() pkt.Drop()
} }
@@ -88,7 +111,7 @@ func (link *Link) Accept(reason string) {
// Deny blocks or drops the link depending on the connection direction and adds the given reason. // Deny blocks or drops the link depending on the connection direction and adds the given reason.
func (link *Link) Deny(reason string) { func (link *Link) Deny(reason string) {
if link.connection.Direction { if link.connection != nil && link.connection.Direction {
link.Drop(reason) link.Drop(reason)
} else { } else {
link.Block(reason) link.Block(reason)
@@ -107,6 +130,17 @@ func (link *Link) Drop(reason string) {
link.UpdateVerdict(DROP) link.UpdateVerdict(DROP)
} }
// RerouteToNameserver reroutes the link to the portmaster nameserver.
func (link *Link) RerouteToNameserver() {
link.UpdateVerdict(RerouteToNameserver)
}
// RerouteToTunnel reroutes the link to the tunnel entrypoint and adds the given reason for accepting the connection.
func (link *Link) RerouteToTunnel(reason string) {
link.AddReason(reason)
link.UpdateVerdict(RerouteToTunnel)
}
// UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts // UpdateVerdict sets a new verdict for this link, making sure it does not interfere with previous verdicts
func (link *Link) UpdateVerdict(newVerdict Verdict) { func (link *Link) UpdateVerdict(newVerdict Verdict) {
link.Lock() link.Lock()
@@ -114,7 +148,7 @@ func (link *Link) UpdateVerdict(newVerdict Verdict) {
if newVerdict > link.Verdict { if newVerdict > link.Verdict {
link.Verdict = newVerdict link.Verdict = newVerdict
link.Save() go link.Save()
} }
} }
@@ -138,54 +172,103 @@ func (link *Link) packetHandler() {
for { for {
pkt := <-link.pktQueue pkt := <-link.pktQueue
if pkt == nil { if pkt == nil {
break return
}
link.Lock()
fwH := link.firewallHandler
link.Unlock()
if fwH != nil {
fwH(pkt, link)
} else {
link.ApplyVerdict(pkt)
}
}
}
// ApplyVerdict appies the link verdict to a packet.
func (link *Link) ApplyVerdict(pkt packet.Packet) {
link.Lock()
defer link.Unlock()
if link.VerdictPermanent {
switch link.Verdict {
case ACCEPT:
pkt.PermanentAccept()
case BLOCK:
pkt.PermanentBlock()
case DROP:
pkt.PermanentDrop()
case RerouteToNameserver:
pkt.RerouteToNameserver()
case RerouteToTunnel:
pkt.RerouteToTunnel()
default:
pkt.Drop()
}
} else {
switch link.Verdict {
case ACCEPT:
pkt.Accept()
case BLOCK:
pkt.Block()
case DROP:
pkt.Drop()
case RerouteToNameserver:
pkt.RerouteToNameserver()
case RerouteToTunnel:
pkt.RerouteToTunnel()
default:
pkt.Drop()
} }
link.firewallHandler(pkt, link)
} }
link.firewallHandler = nil
} }
// Save saves the link object in the storage and propagates the change. // Save saves the link object in the storage and propagates the change.
func (link *Link) Save() error { func (link *Link) Save() error {
link.Lock()
defer link.Unlock()
if link.connection == nil { if link.connection == nil {
return errors.New("cannot save link without connection") return errors.New("cannot save link without connection")
} }
if link.DatabaseKey() == "" { if !link.KeyIsSet() {
link.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", link.connection.Process().Pid, link.connection.Domain, link.ID)) link.SetKey(fmt.Sprintf("network:tree/%d/%s/%s", link.connection.Process().Pid, link.connection.Domain, link.ID))
link.CreateMeta() link.CreateMeta()
} }
dataLock.RLock() linksLock.RLock()
_, ok := links[link.ID] _, ok := links[link.ID]
dataLock.RUnlock() linksLock.RUnlock()
if !ok { if !ok {
dataLock.Lock() linksLock.Lock()
links[link.ID] = link links[link.ID] = link
dataLock.Unlock() linksLock.Unlock()
} }
dbController.PushUpdate(link) go dbController.PushUpdate(link)
return nil return nil
} }
// Delete deletes a link from the storage and propagates the change. // Delete deletes a link from the storage and propagates the change.
func (link *Link) Delete() { func (link *Link) Delete() {
dataLock.Lock()
defer dataLock.Unlock()
delete(links, link.ID)
link.Lock() link.Lock()
defer link.Lock() defer link.Unlock()
linksLock.Lock()
delete(links, link.ID)
linksLock.Unlock()
link.Meta().Delete() link.Meta().Delete()
dbController.PushUpdate(link) go dbController.PushUpdate(link)
link.connection.RemoveLink() link.connection.RemoveLink()
} }
// GetLink fetches a Link from the database from the default namespace for this object // GetLink fetches a Link from the database from the default namespace for this object
func GetLink(id string) (*Link, bool) { func GetLink(id string) (*Link, bool) {
dataLock.RLock() linksLock.RLock()
defer dataLock.RUnlock() defer linksLock.RUnlock()
link, ok := links[id] link, ok := links[id]
return link, ok return link, ok
@@ -215,6 +298,7 @@ func CreateLinkFromPacket(pkt packet.Packet) *Link {
func (link *Link) GetActiveInspectors() []bool { func (link *Link) GetActiveInspectors() []bool {
link.Lock() link.Lock()
defer link.Unlock() defer link.Unlock()
return link.activeInspectors return link.activeInspectors
} }
@@ -222,6 +306,7 @@ func (link *Link) GetActiveInspectors() []bool {
func (link *Link) SetActiveInspectors(new []bool) { func (link *Link) SetActiveInspectors(new []bool) {
link.Lock() link.Lock()
defer link.Unlock() defer link.Unlock()
link.activeInspectors = new link.activeInspectors = new
} }
@@ -229,6 +314,7 @@ func (link *Link) SetActiveInspectors(new []bool) {
func (link *Link) GetInspectorData() map[uint8]interface{} { func (link *Link) GetInspectorData() map[uint8]interface{} {
link.Lock() link.Lock()
defer link.Unlock() defer link.Unlock()
return link.inspectorData return link.inspectorData
} }
@@ -236,11 +322,15 @@ func (link *Link) GetInspectorData() map[uint8]interface{} {
func (link *Link) SetInspectorData(new map[uint8]interface{}) { func (link *Link) SetInspectorData(new map[uint8]interface{}) {
link.Lock() link.Lock()
defer link.Unlock() defer link.Unlock()
link.inspectorData = new link.inspectorData = new
} }
// String returns a string representation of Link. // String returns a string representation of Link.
func (link *Link) String() string { func (link *Link) String() string {
link.Lock()
defer link.Unlock()
if link.connection == nil { if link.connection == nil {
return fmt.Sprintf("? <-> %s", link.RemoteAddress) return fmt.Sprintf("? <-> %s", link.RemoteAddress)
} }

View File

@@ -15,8 +15,8 @@ const (
Invalid Invalid
) )
// ClassifyAddress returns the classification for the given IP address. // ClassifyIP returns the classification for the given IP address.
func ClassifyAddress(ip net.IP) int8 { func ClassifyIP(ip net.IP) int8 {
if ip4 := ip.To4(); ip4 != nil { if ip4 := ip.To4(); ip4 != nil {
// IPv4 // IPv4
switch { switch {
@@ -73,12 +73,12 @@ func ClassifyAddress(ip net.IP) int8 {
// IPIsLocalhost returns whether the IP refers to the host itself. // IPIsLocalhost returns whether the IP refers to the host itself.
func IPIsLocalhost(ip net.IP) bool { func IPIsLocalhost(ip net.IP) bool {
return ClassifyAddress(ip) == HostLocal return ClassifyIP(ip) == HostLocal
} }
// IPIsLAN returns true if the given IP is a site-local or link-local address. // IPIsLAN returns true if the given IP is a site-local or link-local address.
func IPIsLAN(ip net.IP) bool { func IPIsLAN(ip net.IP) bool {
switch ClassifyAddress(ip) { switch ClassifyIP(ip) {
case SiteLocal: case SiteLocal:
return true return true
case LinkLocal: case LinkLocal:
@@ -90,15 +90,15 @@ func IPIsLAN(ip net.IP) bool {
// IPIsGlobal returns true if the given IP is a global address. // IPIsGlobal returns true if the given IP is a global address.
func IPIsGlobal(ip net.IP) bool { func IPIsGlobal(ip net.IP) bool {
return ClassifyAddress(ip) == Global return ClassifyIP(ip) == Global
} }
// IPIsLinkLocal returns true if the given IP is a link-local address. // IPIsLinkLocal returns true if the given IP is a link-local address.
func IPIsLinkLocal(ip net.IP) bool { func IPIsLinkLocal(ip net.IP) bool {
return ClassifyAddress(ip) == LinkLocal return ClassifyIP(ip) == LinkLocal
} }
// IPIsSiteLocal returns true if the given IP is a site-local address. // IPIsSiteLocal returns true if the given IP is a site-local address.
func IPIsSiteLocal(ip net.IP) bool { func IPIsSiteLocal(ip net.IP) bool {
return ClassifyAddress(ip) == SiteLocal return ClassifyIP(ip) == SiteLocal
} }

View File

@@ -13,7 +13,7 @@ func TestIPClassification(t *testing.T) {
} }
func testClassification(t *testing.T, ip net.IP, expectedClassification int8) { func testClassification(t *testing.T, ip net.IP, expectedClassification int8) {
c := ClassifyAddress(ip) c := ClassifyIP(ip)
if c != expectedClassification { if c != expectedClassification {
t.Errorf("%s is %s, expected %s", ip, classificationString(c), classificationString(expectedClassification)) t.Errorf("%s is %s, expected %s", ip, classificationString(c), classificationString(expectedClassification))
} }

View File

@@ -12,6 +12,8 @@ const (
ACCEPT ACCEPT
BLOCK BLOCK
DROP DROP
RerouteToNameserver
RerouteToTunnel
) )
// Packer Directions // Packer Directions

View File

@@ -45,7 +45,7 @@ func (p *Process) Save() {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
if p.DatabaseKey() == "" { if !p.KeyIsSet() {
p.SetKey(fmt.Sprintf("network:tree/%d", p.Pid)) p.SetKey(fmt.Sprintf("network:tree/%d", p.Pid))
p.CreateMeta() p.CreateMeta()
} }
@@ -61,21 +61,22 @@ func (p *Process) Save() {
} }
if dbControllerFlag.IsSet() { if dbControllerFlag.IsSet() {
dbController.PushUpdate(p) go dbController.PushUpdate(p)
} }
} }
// Delete deletes a process from the storage and propagates the change. // Delete deletes a process from the storage and propagates the change.
func (p *Process) Delete() { func (p *Process) Delete() {
processesLock.Lock()
defer processesLock.Unlock()
delete(processes, p.Pid)
p.Lock() p.Lock()
defer p.Lock() defer p.Lock()
p.Meta().Delete()
processesLock.Lock()
delete(processes, p.Pid)
processesLock.Unlock()
p.Meta().Delete()
if dbControllerFlag.IsSet() { if dbControllerFlag.IsSet() {
dbController.PushUpdate(p) go dbController.PushUpdate(p)
} }
profile.DeactivateProfileSet(p.profileSet) profile.DeactivateProfileSet(p.profileSet)
@@ -88,9 +89,11 @@ func CleanProcessStorage(thresholdDuration time.Duration) {
threshold := time.Now().Add(-thresholdDuration).Unix() threshold := time.Now().Add(-thresholdDuration).Unix()
for _, p := range processes { for _, p := range processes {
p.Lock()
if p.FirstConnectionEstablished < threshold && p.ConnectionCount == 0 { if p.FirstConnectionEstablished < threshold && p.ConnectionCount == 0 {
p.Delete() go p.Delete()
} }
p.Unlock()
} }
} }

View File

@@ -4,13 +4,14 @@ import (
"errors" "errors"
"net" "net"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/network/packet" "github.com/Safing/portmaster/network/packet"
) )
// Errors // Errors
var ( var (
ErrConnectionNotFound = errors.New("could not find connection") ErrConnectionNotFound = errors.New("could not find connection in system state tables")
ErrProcessNotFound = errors.New("could not find process") ErrProcessNotFound = errors.New("could not find process in system state tables")
) )
// GetPidByPacket returns the pid of the owner of the packet. // GetPidByPacket returns the pid of the owner of the packet.
@@ -57,18 +58,23 @@ func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, er
var pid int var pid int
pid, direction, err = GetPidByPacket(pkt) pid, direction, err = GetPidByPacket(pkt)
if pid < 0 {
return nil, direction, ErrConnectionNotFound
}
if err != nil { if err != nil {
return nil, direction, err return nil, direction, err
} }
if pid < 0 {
return nil, direction, ErrConnectionNotFound
}
process, err = GetOrFindProcess(pid) process, err = GetOrFindProcess(pid)
if err != nil { if err != nil {
return nil, direction, err return nil, direction, err
} }
err = process.FindProfiles()
if err != nil {
log.Errorf("failed to find profiles for process %s: %s", process.String(), err)
}
return process, direction, nil return process, direction, nil
} }
@@ -113,6 +119,11 @@ func GetProcessByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, re
return nil, err return nil, err
} }
err = process.FindProfiles()
if err != nil {
log.Errorf("failed to find profiles for process %s: %s", process.String(), err)
}
return process, nil return process, nil
} }

View File

@@ -1,52 +1,59 @@
package process package process
import ( import (
"fmt"
"github.com/Safing/portbase/database" "github.com/Safing/portbase/database"
"github.com/Safing/portbase/database/query"
"github.com/Safing/portbase/log" "github.com/Safing/portbase/log"
"github.com/Safing/portmaster/profile" "github.com/Safing/portmaster/profile"
"github.com/Safing/portmaster/profile/index" )
var (
profileDB = database.NewInterface(nil)
) )
// FindProfiles finds and assigns a profile set to the process. // FindProfiles finds and assigns a profile set to the process.
func (p *Process) FindProfiles() error { func (p *Process) FindProfiles() error {
// Get fingerprints of process p.Lock()
defer p.Unlock()
// Check if user profile already exists, else create new // only find profiles if not already done.
pathIdentifier := profile.GetPathIdentifier(p.Path) if p.profileSet != nil {
indexRecord, err := index.Get(pathIdentifier) return nil
if err != nil && err != database.ErrNotFound {
log.Errorf("process: could not get profile index for %s: %s", pathIdentifier, err)
} }
var possibleProfiles []*profile.Profile // User Profile
if indexRecord != nil { it, err := profileDB.Query(query.New(profile.MakeProfileKey(profile.UserNamespace, "")).Where(query.Where("LinkedPath", query.SameAs, p.Path)))
for _, profileID := range indexRecord.UserProfiles { if err != nil {
prof, err := profile.Get(profileID) return err
if err != nil { }
log.Errorf("process: failed to load profile %s: %s", profileID, err)
} var userProfile *profile.Profile
possibleProfiles = append(possibleProfiles, prof) for r := range it.Next {
} it.Cancel()
userProfile, err = profile.EnsureProfile(r)
if err != nil {
return err
} }
break
}
if it.Err() != nil {
return it.Err()
} }
prof := selectProfile(p, possibleProfiles) // create new profile if it does not exist.
if prof == nil { if userProfile == nil {
// create new profile // create new profile
prof := profile.New() userProfile = profile.New()
prof.Name = p.ExecName userProfile.Name = p.ExecName
prof.AddFingerprint(&profile.Fingerprint{ userProfile.LinkedPath = p.Path
Type: "full_path",
Value: p.Path,
})
// TODO: maybe add sha256_sum?
prof.MarkUsed()
prof.Save()
} }
if userProfile.MarkUsed() {
userProfile.Save(profile.UserNamespace)
}
// Stamp
// Find/Re-evaluate Stamp profile // Find/Re-evaluate Stamp profile
// 1. check linked stamp profile // 1. check linked stamp profile
// 2. if last check is was more than a week ago, fetch from stamp: // 2. if last check is was more than a week ago, fetch from stamp:
@@ -56,13 +63,9 @@ func (p *Process) FindProfiles() error {
// 6. link stamp profile to user profile // 6. link stamp profile to user profile
// FIXME: implement! // FIXME: implement!
if prof.MarkUsed() { p.UserProfileKey = userProfile.Key()
prof.Save() p.profileSet = profile.NewSet(userProfile, nil)
} go p.Save()
p.UserProfileKey = prof.Key()
p.profileSet = profile.NewSet(prof, nil)
p.Save()
return nil return nil
} }

View File

@@ -13,14 +13,19 @@ const (
NoProcess NoProcess
) )
func GetPidOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { var (
waitTime = 15 * time.Millisecond
)
// GetPidOfConnection returns the PID of the given connection.
func GetPidOfConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
uid, inode, ok := getConnectionSocket(localIP, localPort, protocol) uid, inode, ok := getConnectionSocket(localIP, localPort, protocol)
if !ok { if !ok {
uid, inode, ok = getListeningSocket(localIP, localPort, protocol) uid, inode, ok = getListeningSocket(localIP, localPort, protocol)
for i := 0; i < 3 && !ok; i++ { for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again // give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think") // log.Tracef("process: giving kernel some time to think")
time.Sleep(15 * time.Millisecond) time.Sleep(waitTime)
uid, inode, ok = getConnectionSocket(localIP, localPort, protocol) uid, inode, ok = getConnectionSocket(localIP, localPort, protocol)
if !ok { if !ok {
uid, inode, ok = getListeningSocket(localIP, localPort, protocol) uid, inode, ok = getListeningSocket(localIP, localPort, protocol)
@@ -30,27 +35,48 @@ func GetPidOfConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid
return -1, NoSocket return -1, NoSocket
} }
} }
pid, ok = GetPidOfInode(uid, inode) pid, ok = GetPidOfInode(uid, inode)
for i := 0; i < 3 && !ok; i++ { for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again // give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think") // log.Tracef("process: giving kernel some time to think")
time.Sleep(15 * time.Millisecond) time.Sleep(waitTime)
pid, ok = GetPidOfInode(uid, inode) pid, ok = GetPidOfInode(uid, inode)
} }
if !ok { if !ok {
return -1, NoProcess return -1, NoProcess
} }
return return
} }
func GetPidOfIncomingConnection(localIP *net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { // GetPidOfConnection returns the PID of the given incoming connection.
func GetPidOfIncomingConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) {
uid, inode, ok := getListeningSocket(localIP, localPort, protocol) uid, inode, ok := getListeningSocket(localIP, localPort, protocol)
if !ok { if !ok {
return -1, NoSocket // for TCP4 and UDP4, also try TCP6 and UDP6, as linux sometimes treats them as a single dual socket, and shows the IPv6 version.
switch protocol {
case TCP4:
uid, inode, ok = getListeningSocket(localIP, localPort, TCP6)
case UDP4:
uid, inode, ok = getListeningSocket(localIP, localPort, UDP6)
}
if !ok {
return -1, NoSocket
}
} }
pid, ok = GetPidOfInode(uid, inode) pid, ok = GetPidOfInode(uid, inode)
for i := 0; i < 3 && !ok; i++ {
// give kernel some time, then try again
// log.Tracef("process: giving kernel some time to think")
time.Sleep(waitTime)
pid, ok = GetPidOfInode(uid, inode)
}
if !ok { if !ok {
return -1, NoProcess return -1, NoProcess
} }
return return
} }

View File

@@ -6,39 +6,39 @@ import (
) )
func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(TCP4, localIP, localPort, direction) return search(TCP4, localIP, localPort, pktDirection)
} }
func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(TCP6, localIP, localPort, direction) return search(TCP6, localIP, localPort, pktDirection)
} }
func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(UDP4, localIP, localPort, direction) return search(UDP4, localIP, localPort, pktDirection)
} }
func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) {
return search(UDP6, localIP, localPort, direction) return search(UDP6, localIP, localPort, pktDirection)
} }
func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) (pid int, direction bool, err error) { func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) (pid int, direction bool, err error) {
var status uint8 var status uint8
if pktDirection { if pktDirection {
pid, status = GetPidOfIncomingConnection(&localIP, localPort, protocol) pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol)
if pid >= 0 { if pid >= 0 {
return pid, true, nil return pid, true, nil
} }
// pid, status = GetPidOfConnection(&localIP, localPort, protocol) // pid, status = GetPidOfConnection(localIP, localPort, protocol)
// if pid >= 0 { // if pid >= 0 {
// return pid, false, nil // return pid, false, nil
// } // }
} else { } else {
pid, status = GetPidOfConnection(&localIP, localPort, protocol) pid, status = GetPidOfConnection(localIP, localPort, protocol)
if pid >= 0 { if pid >= 0 {
return pid, false, nil return pid, false, nil
} }
// pid, status = GetPidOfIncomingConnection(&localIP, localPort, protocol) // pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol)
// if pid >= 0 { // if pid >= 0 {
// return pid, true, nil // return pid, true, nil
// } // }

View File

@@ -81,7 +81,7 @@ var (
globalListeningUDP6 = make(map[uint16][]int) globalListeningUDP6 = make(map[uint16][]int)
) )
func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int, int, bool) { func getConnectionSocket(localIP net.IP, localPort uint16, protocol uint8) (int, int, bool) {
// listeningSocketsLock.Lock() // listeningSocketsLock.Lock()
// defer listeningSocketsLock.Unlock() // defer listeningSocketsLock.Unlock()
@@ -98,10 +98,10 @@ func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int
localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]})) localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]}))
case TCP6: case TCP6:
procFile = TCP6Data procFile = TCP6Data
localIPHex = hex.EncodeToString([]byte(*localIP)) localIPHex = hex.EncodeToString([]byte(localIP))
case UDP6: case UDP6:
procFile = UDP6Data procFile = UDP6Data
localIPHex = hex.EncodeToString([]byte(*localIP)) localIPHex = hex.EncodeToString([]byte(localIP))
} }
localPortHex := fmt.Sprintf("%04X", localPort) localPortHex := fmt.Sprintf("%04X", localPort)
@@ -162,38 +162,38 @@ func getConnectionSocket(localIP *net.IP, localPort uint16, protocol uint8) (int
} }
func getListeningSocket(localIP *net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) { func getListeningSocket(localIP net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) {
listeningSocketsLock.Lock() listeningSocketsLock.Lock()
defer listeningSocketsLock.Unlock() defer listeningSocketsLock.Unlock()
var addressListening *map[string][]int var addressListening map[string][]int
var globalListening *map[uint16][]int var globalListening map[uint16][]int
switch protocol { switch protocol {
case TCP4: case TCP4:
addressListening = &addressListeningTCP4 addressListening = addressListeningTCP4
globalListening = &globalListeningTCP4 globalListening = globalListeningTCP4
case UDP4: case UDP4:
addressListening = &addressListeningUDP4 addressListening = addressListeningUDP4
globalListening = &globalListeningUDP4 globalListening = globalListeningUDP4
case TCP6: case TCP6:
addressListening = &addressListeningTCP6 addressListening = addressListeningTCP6
globalListening = &globalListeningTCP6 globalListening = globalListeningTCP6
case UDP6: case UDP6:
addressListening = &addressListeningUDP6 addressListening = addressListeningUDP6
globalListening = &globalListeningUDP6 globalListening = globalListeningUDP6
} }
data, ok := (*addressListening)[fmt.Sprintf("%s:%d", localIP, localPort)] data, ok := addressListening[fmt.Sprintf("%s:%d", localIP, localPort)]
if !ok { if !ok {
data, ok = (*globalListening)[localPort] data, ok = globalListening[localPort]
} }
if ok { if ok {
return data[0], data[1], true return data[0], data[1], true
} }
updateListeners(protocol) updateListeners(protocol)
data, ok = (*addressListening)[fmt.Sprintf("%s:%d", localIP, localPort)] data, ok = addressListening[fmt.Sprintf("%s:%d", localIP, localPort)]
if !ok { if !ok {
data, ok = (*globalListening)[localPort] data, ok = globalListening[localPort]
} }
if ok { if ok {
return data[0], data[1], true return data[0], data[1], true
@@ -206,7 +206,7 @@ func procDelimiter(c rune) bool {
return unicode.IsSpace(c) || c == ':' return unicode.IsSpace(c) || c == ':'
} }
func convertIPv4(data string) *net.IP { func convertIPv4(data string) net.IP {
decoded, err := hex.DecodeString(data) decoded, err := hex.DecodeString(data)
if err != nil { if err != nil {
log.Warningf("process: could not parse IPv4 %s: %s", data, err) log.Warningf("process: could not parse IPv4 %s: %s", data, err)
@@ -217,10 +217,10 @@ func convertIPv4(data string) *net.IP {
return nil return nil
} }
ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0]) ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
return &ip return ip
} }
func convertIPv6(data string) *net.IP { func convertIPv6(data string) net.IP {
decoded, err := hex.DecodeString(data) decoded, err := hex.DecodeString(data)
if err != nil { if err != nil {
log.Warningf("process: could not parse IPv6 %s: %s", data, err) log.Warningf("process: could not parse IPv6 %s: %s", data, err)
@@ -231,7 +231,7 @@ func convertIPv6(data string) *net.IP {
return nil return nil
} }
ip := net.IP(decoded) ip := net.IP(decoded)
return &ip return ip
} }
func updateListeners(protocol uint8) { func updateListeners(protocol uint8) {
@@ -247,7 +247,7 @@ func updateListeners(protocol uint8) {
} }
} }
func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) *net.IP) (map[string][]int, map[uint16][]int) { func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) net.IP) (map[string][]int, map[uint16][]int) {
addressListening := make(map[string][]int) addressListening := make(map[string][]int)
globalListening := make(map[uint16][]int) globalListening := make(map[uint16][]int)
@@ -312,6 +312,7 @@ func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter
return addressListening, globalListening return addressListening, globalListening
} }
// GetActiveConnectionIDs returns all connection IDs that are still marked as active by the OS.
func GetActiveConnectionIDs() []string { func GetActiveConnectionIDs() []string {
var connections []string var connections []string
@@ -323,7 +324,7 @@ func GetActiveConnectionIDs() []string {
return connections return connections
} }
func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) *net.IP) []string { func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) net.IP) []string {
var connections []string var connections []string
// open file // open file

View File

@@ -22,14 +22,14 @@ func TestSockets(t *testing.T) {
t.Logf("addressListeningUDP6: %v", addressListeningUDP6) t.Logf("addressListeningUDP6: %v", addressListeningUDP6)
t.Logf("globalListeningUDP6: %v", globalListeningUDP6) t.Logf("globalListeningUDP6: %v", globalListeningUDP6)
getListeningSocket(&net.IPv4zero, 53, TCP4) getListeningSocket(net.IPv4zero, 53, TCP4)
getListeningSocket(&net.IPv4zero, 53, UDP4) getListeningSocket(net.IPv4zero, 53, UDP4)
getListeningSocket(&net.IPv6zero, 53, TCP6) getListeningSocket(net.IPv6zero, 53, TCP6)
getListeningSocket(&net.IPv6zero, 53, UDP6) getListeningSocket(net.IPv6zero, 53, UDP6)
// spotify: 192.168.0.102:5353 192.121.140.65:80 // spotify: 192.168.0.102:5353 192.121.140.65:80
localIP := net.IPv4(192, 168, 127, 10) localIP := net.IPv4(192, 168, 127, 10)
uid, inode, ok := getConnectionSocket(&localIP, 46634, TCP4) uid, inode, ok := getConnectionSocket(localIP, 46634, TCP4)
t.Logf("getConnectionSocket: %d %d %v", uid, inode, ok) t.Logf("getConnectionSocket: %d %d %v", uid, inode, ok)
activeConnectionIDs := GetActiveConnectionIDs() activeConnectionIDs := GetActiveConnectionIDs()

View File

@@ -49,11 +49,17 @@ type Process struct {
// ProfileSet returns the assigned profile set. // ProfileSet returns the assigned profile set.
func (p *Process) ProfileSet() *profile.Set { func (p *Process) ProfileSet() *profile.Set {
p.Lock()
defer p.Unlock()
return p.profileSet return p.profileSet
} }
// Strings returns a string represenation of process. // Strings returns a string represenation of process.
func (p *Process) String() string { func (p *Process) String() string {
p.Lock()
defer p.Unlock()
if p == nil { if p == nil {
return "?" return "?"
} }
@@ -64,6 +70,7 @@ func (p *Process) String() string {
func (p *Process) AddConnection() { func (p *Process) AddConnection() {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
p.ConnectionCount++ p.ConnectionCount++
p.LastConnectionEstablished = time.Now().Unix() p.LastConnectionEstablished = time.Now().Unix()
if p.FirstConnectionEstablished == 0 { if p.FirstConnectionEstablished == 0 {
@@ -75,6 +82,7 @@ func (p *Process) AddConnection() {
func (p *Process) RemoveConnection() { func (p *Process) RemoveConnection() {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
if p.ConnectionCount > 0 { if p.ConnectionCount > 0 {
p.ConnectionCount-- p.ConnectionCount--
} }
@@ -235,9 +243,8 @@ func GetOrFindProcess(pid int) (*Process, error) {
// Executable Information // Executable Information
// FIXME: use os specific path seperator // FIXME: use os specific path seperator
splittedPath := strings.Split("/", new.Path) splittedPath := strings.Split(new.Path, "/")
new.ExecName = strings.ToTitle(splittedPath[len(splittedPath)-1]) new.ExecName = splittedPath[len(splittedPath)-1]
} }
// save to storage // save to storage

View File

@@ -12,9 +12,9 @@ import (
// Namespaces // Namespaces
const ( const (
userNamespace = "user" UserNamespace = "user"
stampNamespace = "stamp" StampNamespace = "stamp"
specialNamespace = "special" SpecialNamespace = "special"
) )
var ( var (

View File

@@ -90,7 +90,7 @@ func Get(fpType, id string) (*ProfileIndex, error) {
// Save saves the Identifiers to the database // Save saves the Identifiers to the database
func (pi *ProfileIndex) Save() error { func (pi *ProfileIndex) Save() error {
if pi.Key() == "" { if !pi.KeyIsSet() {
if pi.ID != "" { if pi.ID != "" {
pi.SetKey(makeIndexRecordKey(pi.ID)) pi.SetKey(makeIndexRecordKey(pi.ID))
} else { } else {

View File

@@ -7,7 +7,7 @@ var (
) )
func init() { func init() {
modules.Register("profile", nil, start, stop, "database") modules.Register("profile", nil, start, stop, "global", "database")
} }
func start() error { func start() error {

View File

@@ -33,7 +33,9 @@ func TestPorts(t *testing.T) {
}, },
}, },
} }
if ports.String() != "TCP:[permit:22], <UDP:[deny:80-81], 93:[permit:93]" { if ports.String() != "TCP:[permit:22], <UDP:[deny:80-81], 93:[permit:93]" &&
ports.String() != "93:[permit:93], TCP:[permit:22], <UDP:[deny:80-81]" &&
ports.String() != "<UDP:[deny:80-81], 93:[permit:93], TCP:[permit:22]" {
t.Errorf("unexpected result: %s", ports.String()) t.Errorf("unexpected result: %s", ports.String())
} }

View File

@@ -28,6 +28,11 @@ type Profile struct {
// Icon is a path to the icon and is either prefixed "f:" for filepath, "d:" for a database path or "e:" for the encoded data. // Icon is a path to the icon and is either prefixed "f:" for filepath, "d:" for a database path or "e:" for the encoded data.
Icon string Icon string
// User Profile Only
LinkedPath string `json:",omitempty"`
StampProfileID string `json:",omitempty"`
StampProfileAssigned int64 `json:",omitempty"`
// Fingerprints // Fingerprints
Fingerprints []*Fingerprint Fingerprints []*Fingerprint
@@ -37,24 +42,22 @@ type Profile struct {
Domains Domains Domains Domains
Ports Ports Ports Ports
// User Profile Only
CoupledPath string `json:",omitempty"`
StampProfileKey string `json:",omitempty"`
StampProfileAssigned int64 `json:",omitempty"`
// If a Profile is declared as a Framework (i.e. an Interpreter and the likes), then the real process must be found // If a Profile is declared as a Framework (i.e. an Interpreter and the likes), then the real process must be found
// Framework *Framework `json:",omitempty bson:",omitempty"` // Framework *Framework `json:",omitempty bson:",omitempty"`
// When this Profile was approximately last used (for performance reasons not every single usage is saved) // When this Profile was approximately last used (for performance reasons not every single usage is saved)
Created int64
ApproxLastUsed int64 ApproxLastUsed int64
} }
// New returns a new Profile. // New returns a new Profile.
func New() *Profile { func New() *Profile {
return &Profile{} return &Profile{
Created: time.Now().Unix(),
}
} }
func makeProfileKey(namespace, ID string) string { func MakeProfileKey(namespace, ID string) string {
return fmt.Sprintf("core:profiles/%s/%s", namespace, ID) return fmt.Sprintf("core:profiles/%s/%s", namespace, ID)
} }
@@ -68,11 +71,11 @@ func (profile *Profile) Save(namespace string) error {
profile.ID = u.String() profile.ID = u.String()
} }
if profile.Key() == "" { if !profile.KeyIsSet() {
if namespace == "" { if namespace == "" {
return fmt.Errorf("no key or namespace defined for profile %s", profile.String()) return fmt.Errorf("no key or namespace defined for profile %s", profile.String())
} }
profile.SetKey(makeProfileKey(namespace, profile.ID)) profile.SetKey(MakeProfileKey(namespace, profile.ID))
} }
return profileDB.Put(profile) return profileDB.Put(profile)
@@ -99,23 +102,24 @@ func (profile *Profile) DetailedString() string {
// GetUserProfile loads a profile from the database. // GetUserProfile loads a profile from the database.
func GetUserProfile(ID string) (*Profile, error) { func GetUserProfile(ID string) (*Profile, error) {
return getProfile(userNamespace, ID) return getProfile(UserNamespace, ID)
} }
// GetStampProfile loads a profile from the database. // GetStampProfile loads a profile from the database.
func GetStampProfile(ID string) (*Profile, error) { func GetStampProfile(ID string) (*Profile, error) {
return getProfile(stampNamespace, ID) return getProfile(StampNamespace, ID)
} }
func getProfile(namespace, ID string) (*Profile, error) { func getProfile(namespace, ID string) (*Profile, error) {
r, err := profileDB.Get(makeProfileKey(namespace, ID)) r, err := profileDB.Get(MakeProfileKey(namespace, ID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ensureProfile(r) return EnsureProfile(r)
} }
func ensureProfile(r record.Record) (*Profile, error) { // EnsureProfile ensures that the given record is a *Profile, and returns it.
func EnsureProfile(r record.Record) (*Profile, error) {
// unwrap // unwrap
if r.IsWrapped() { if r.IsWrapped() {
// only allocate a new struct, if we need it // only allocate a new struct, if we need it

View File

@@ -21,8 +21,8 @@ type Set struct {
// Stamp // Stamp
// Default // Default
securityLevel uint8 combinedSecurityLevel uint8
independent bool independent bool
} }
// NewSet returns a new profile set with given the profiles. // NewSet returns a new profile set with given the profiles.
@@ -40,6 +40,11 @@ func NewSet(user, stamp *Profile) *Set {
return new return new
} }
// UserProfile returns the user profile.
func (set *Set) UserProfile() *Profile {
return set.profiles[0]
}
// Update gets the new global and default profile and updates the independence status. It must be called when reusing a profile set for a series of calls. // Update gets the new global and default profile and updates the independence status. It must be called when reusing a profile set for a series of calls.
func (set *Set) Update(securityLevel uint8) { func (set *Set) Update(securityLevel uint8) {
set.Lock() set.Lock()
@@ -52,15 +57,15 @@ func (set *Set) Update(securityLevel uint8) {
set.profiles[3] = fallbackProfile set.profiles[3] = fallbackProfile
// update security level // update security level
profileSecurityLevel := set.getProfileSecurityLevel() profileSecurityLevel := set.getSecurityLevel()
if profileSecurityLevel > securityLevel { if profileSecurityLevel > securityLevel {
set.securityLevel = profileSecurityLevel set.combinedSecurityLevel = profileSecurityLevel
} else { } else {
set.securityLevel = securityLevel set.combinedSecurityLevel = securityLevel
} }
// update independence
set.Unlock() set.Unlock()
// update independence
if set.CheckFlag(Independent) { if set.CheckFlag(Independent) {
set.Lock() set.Lock()
set.independent = true set.independent = true
@@ -72,6 +77,14 @@ func (set *Set) Update(securityLevel uint8) {
} }
} }
// SecurityLevel returns the applicable security level for the profile set.
func (set *Set) SecurityLevel() uint8 {
set.Lock()
defer set.Unlock()
return set.combinedSecurityLevel
}
// GetProfileMode returns the active profile mode. // GetProfileMode returns the active profile mode.
func (set *Set) GetProfileMode() uint8 { func (set *Set) GetProfileMode() uint8 {
switch { switch {
@@ -97,7 +110,7 @@ func (set *Set) CheckFlag(flag uint8) (active bool) {
} }
if profile != nil { if profile != nil {
active, ok := profile.Flags.Check(flag, set.securityLevel) active, ok := profile.Flags.Check(flag, set.combinedSecurityLevel)
if ok { if ok {
return active return active
} }
@@ -153,8 +166,11 @@ func (set *Set) CheckPort(listen bool, protocol uint8, port uint16) (permit, ok
return false, false return false, false
} }
// SecurityLevel returns the highest prioritized security level. // getSecurityLevel returns the highest prioritized security level.
func (set *Set) getProfileSecurityLevel() uint8 { func (set *Set) getSecurityLevel() uint8 {
if set == nil {
return 0
}
for i, profile := range set.profiles { for i, profile := range set.profiles {
if i == 2 { if i == 2 {

View File

@@ -24,7 +24,7 @@ func initSpecialProfiles() (err error) {
return err return err
} }
globalProfile = makeDefaultGlobalProfile() globalProfile = makeDefaultGlobalProfile()
globalProfile.Save(specialNamespace) globalProfile.Save(SpecialNamespace)
} }
fallbackProfile, err = getSpecialProfile("fallback") fallbackProfile, err = getSpecialProfile("fallback")
@@ -33,12 +33,12 @@ func initSpecialProfiles() (err error) {
return err return err
} }
fallbackProfile = makeDefaultFallbackProfile() fallbackProfile = makeDefaultFallbackProfile()
fallbackProfile.Save(specialNamespace) fallbackProfile.Save(SpecialNamespace)
} }
return nil return nil
} }
func getSpecialProfile(ID string) (*Profile, error) { func getSpecialProfile(ID string) (*Profile, error) {
return getProfile(specialNamespace, ID) return getProfile(SpecialNamespace, ID)
} }

View File

@@ -9,7 +9,7 @@ import (
) )
func initUpdateListener() error { func initUpdateListener() error {
sub, err := profileDB.Subscribe(query.New(makeProfileKey(specialNamespace, ""))) sub, err := profileDB.Subscribe(query.New(MakeProfileKey(SpecialNamespace, "")))
if err != nil { if err != nil {
return err return err
} }
@@ -29,7 +29,7 @@ func updateListener(sub *database.Subscription) {
continue continue
} }
profile, err := ensureProfile(r) profile, err := EnsureProfile(r)
if err != nil { if err != nil {
log.Errorf("profile: received update for special profile, but could not read: %s", err) log.Errorf("profile: received update for special profile, but could not read: %s", err)
continue continue
@@ -46,9 +46,9 @@ func updateListener(sub *database.Subscription) {
specialProfileLock.Unlock() specialProfileLock.Unlock()
default: default:
switch { switch {
case strings.HasPrefix(profile.Key(), makeProfileKey(userNamespace, "")): case strings.HasPrefix(profile.Key(), MakeProfileKey(UserNamespace, "")):
updateActiveUserProfile(profile) updateActiveUserProfile(profile)
case strings.HasPrefix(profile.Key(), makeProfileKey(stampNamespace, "")): case strings.HasPrefix(profile.Key(), MakeProfileKey(StampNamespace, "")):
updateActiveStampProfile(profile) updateActiveStampProfile(profile)
} }
} }

25
status/module.go Normal file
View File

@@ -0,0 +1,25 @@
package status
import (
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/modules"
)
func init() {
modules.Register("status", prep, nil, nil)
}
func prep() error {
if CurrentSecurityLevel() == SecurityLevelOff {
log.Infof("switching to default active security level: dynamic")
SetCurrentSecurityLevel(SecurityLevelDynamic)
}
if SelectedSecurityLevel() == SecurityLevelOff {
log.Infof("switching to default selected security level: dynamic")
SetSelectedSecurityLevel(SecurityLevelDynamic)
}
return nil
}