Block until pending verdicts are set. Update deps
This commit is contained in:
@@ -5,10 +5,12 @@ package nfqexp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
pmpacket "github.com/safing/portmaster/network/packet"
|
||||
"github.com/tevino/abool"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/florianl/go-nfqueue"
|
||||
@@ -20,6 +22,9 @@ type Queue struct {
|
||||
nf *nfqueue.Nfqueue
|
||||
packets chan pmpacket.Packet
|
||||
cancelSocketCallback context.CancelFunc
|
||||
|
||||
pendingVerdicts uint64
|
||||
verdictCompleted chan struct{}
|
||||
}
|
||||
|
||||
// New opens a new nfQueue.
|
||||
@@ -30,12 +35,12 @@ func New(qid uint16, v6 bool) (*Queue, error) {
|
||||
}
|
||||
cfg := &nfqueue.Config{
|
||||
NfQueue: qid,
|
||||
MaxPacketLen: 0xffff,
|
||||
MaxQueueLen: 0xff,
|
||||
MaxPacketLen: 0xff,
|
||||
MaxQueueLen: 0xffff,
|
||||
AfFamily: uint8(afFamily),
|
||||
Copymode: nfqueue.NfQnlCopyPacket,
|
||||
ReadTimeout: 50 * time.Millisecond,
|
||||
WriteTimeout: 50 * time.Millisecond,
|
||||
ReadTimeout: 5 * time.Millisecond,
|
||||
WriteTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
nf, err := nfqueue.Open(cfg)
|
||||
@@ -49,6 +54,7 @@ func New(qid uint16, v6 bool) (*Queue, error) {
|
||||
nf: nf,
|
||||
packets: make(chan pmpacket.Packet, 1000),
|
||||
cancelSocketCallback: cancel,
|
||||
verdictCompleted: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
fn := func(attrs nfqueue.Attribute) int {
|
||||
@@ -61,10 +67,11 @@ func New(qid uint16, v6 bool) (*Queue, error) {
|
||||
}
|
||||
|
||||
pkt := &packet{
|
||||
ID: *attrs.PacketID,
|
||||
queue: q,
|
||||
received: time.Now(),
|
||||
verdictSet: make(chan struct{}),
|
||||
pktID: *attrs.PacketID,
|
||||
queue: q,
|
||||
received: time.Now(),
|
||||
verdictSet: make(chan struct{}),
|
||||
verdictPending: abool.New(),
|
||||
}
|
||||
|
||||
if attrs.Payload != nil {
|
||||
@@ -79,7 +86,7 @@ func New(qid uint16, v6 bool) (*Queue, error) {
|
||||
|
||||
select {
|
||||
case q.packets <- pkt:
|
||||
log.Tracef("nfqexp: queued packet %d (%s -> %s) after %s", pkt.ID, pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received))
|
||||
log.Tracef("nfqexp: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received))
|
||||
case <-ctx.Done():
|
||||
return 0
|
||||
case <-time.After(time.Second):
|
||||
@@ -90,10 +97,10 @@ func New(qid uint16, v6 bool) (*Queue, error) {
|
||||
select {
|
||||
case <-pkt.verdictSet:
|
||||
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Warningf("nfqexp: no verdict set for packet %d (%s -> %s) after %s, dropping", pkt.ID, pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received))
|
||||
case <-time.After(20 * time.Second):
|
||||
log.Warningf("nfqexp: no verdict set for packet %s (%s -> %s) after %s, dropping", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received))
|
||||
if err := pkt.Drop(); err != nil {
|
||||
log.Warningf("nfqexp: failed to apply default-drop to unveridcted packet %d (%s -> %s)", pkt.ID, pkt.Info().Src, pkt.Info().Dst)
|
||||
log.Warningf("nfqexp: failed to apply default-drop to unveridcted packet %s (%s -> %s)", pkt.ID(), pkt.Info().Src, pkt.Info().Dst)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -101,7 +108,32 @@ func New(qid uint16, v6 bool) (*Queue, error) {
|
||||
return 0 // continue calling this fn
|
||||
}
|
||||
|
||||
if err := q.nf.Register(ctx, fn); err != nil {
|
||||
errorFunc := func(e error) int {
|
||||
// embedded interface is required to work-around some
|
||||
// dep-vendoring weirdness
|
||||
if opError, ok := e.(interface {
|
||||
Timeout() bool
|
||||
Temporary() bool
|
||||
}); ok {
|
||||
if opError.Timeout() || opError.Temporary() {
|
||||
c := atomic.LoadUint64(&q.pendingVerdicts)
|
||||
if c > 0 {
|
||||
log.Tracef("nfqexp: waiting for %d pending verdicts", c)
|
||||
|
||||
for atomic.LoadUint64(&q.pendingVerdicts) > 0 { // must NOT use c here
|
||||
<-q.verdictCompleted
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
}
|
||||
log.Errorf("nfqexp: encountered error while receiving packets: %s\n", e.Error())
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
if err := q.nf.RegisterWithErrorFunc(ctx, fn, errorFunc); err != nil {
|
||||
defer q.nf.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4,10 +4,13 @@ package nfqexp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/florianl/go-nfqueue"
|
||||
"github.com/mdlayher/netlink"
|
||||
"github.com/safing/portbase/log"
|
||||
pmpacket "github.com/safing/portmaster/network/packet"
|
||||
)
|
||||
@@ -51,10 +54,15 @@ func markToString(mark int) string {
|
||||
// packet implements the packet.Packet interface.
|
||||
type packet struct {
|
||||
pmpacket.Base
|
||||
ID uint32
|
||||
received time.Time
|
||||
queue *Queue
|
||||
verdictSet chan struct{}
|
||||
pktID uint32
|
||||
received time.Time
|
||||
queue *Queue
|
||||
verdictSet chan struct{}
|
||||
verdictPending *abool.AtomicBool
|
||||
}
|
||||
|
||||
func (pkt *packet) ID() string {
|
||||
return fmt.Sprintf("pkt:%d qid:%d", pkt.pktID, pkt.queue.id)
|
||||
}
|
||||
|
||||
// TODO(ppacher): revisit the following behavior:
|
||||
@@ -68,26 +76,44 @@ type packet struct {
|
||||
// raw-socket.
|
||||
//
|
||||
func (pkt *packet) mark(mark int) (err error) {
|
||||
if pkt.verdictPending.SetToIf(false, true) {
|
||||
defer close(pkt.verdictSet)
|
||||
return pkt.setMark(mark)
|
||||
}
|
||||
|
||||
return errors.New("verdict set")
|
||||
}
|
||||
|
||||
func (pkt *packet) setMark(mark int) error {
|
||||
atomic.AddUint64(&pkt.queue.pendingVerdicts, 1)
|
||||
|
||||
defer func() {
|
||||
if x := recover(); x != nil {
|
||||
err = errors.New("verdict set")
|
||||
atomic.AddUint64(&pkt.queue.pendingVerdicts, ^uint64(0))
|
||||
select {
|
||||
case pkt.queue.verdictCompleted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if err := pkt.queue.nf.SetVerdictWithMark(pkt.ID, nfqueue.NfAccept, mark); err != nil {
|
||||
log.Warningf("nfqexp: failed to set verdict %s for %d (%s -> %s): %s", markToString(mark), pkt.ID, pkt.Info().Src, pkt.Info().Dst, err)
|
||||
if opErr, ok := err.(*netlink.OpError); ok {
|
||||
if err := pkt.queue.nf.SetVerdictWithMark(pkt.pktID, nfqueue.NfAccept, mark); err != nil {
|
||||
// embedded interface is required to work-around some
|
||||
// dep-vendoring weirdness
|
||||
if opErr, ok := err.(interface {
|
||||
Timeout() bool
|
||||
Temporary() bool
|
||||
}); ok {
|
||||
if opErr.Timeout() || opErr.Temporary() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
log.Errorf("nfqexp: failed to set verdict %s for %s (%s -> %s): %s", markToString(mark), pkt.ID(), pkt.Info().Src, pkt.Info().Dst, err)
|
||||
return err
|
||||
}
|
||||
break
|
||||
}
|
||||
log.Tracef("nfqexp: marking packet %d (%s -> %s) on queue %d with %s after %s", pkt.ID, pkt.Info().Src, pkt.Info().Dst, pkt.queue.id, markToString(mark), time.Since(pkt.received))
|
||||
close(pkt.verdictSet)
|
||||
log.Tracef("nfqexp: marking packet %s (%s -> %s) on queue %d with %s after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id, markToString(mark), time.Since(pkt.received))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user