diff --git a/network/iphelper/tables.go b/network/iphelper/tables.go index f8f99802..9d59f7d7 100644 --- a/network/iphelper/tables.go +++ b/network/iphelper/tables.go @@ -95,8 +95,28 @@ const ( UDP uint8 = 17 ) +type learningBufSize struct { + sync.Mutex + + size int + usesLeft int + useFor int + start int + max int +} + +func newLearningBufSize(start, max, ttl int) *learningBufSize { + return &learningBufSize{ + size: start, + usesLeft: ttl, + useFor: ttl, + start: start, + max: max, + } +} + const ( - startBufSize = 4096 + startBufSize = 1024 // bufSizeUsageTTL defines how often a buffer size is used before it is // shrunk again. @@ -108,50 +128,51 @@ const ( ) var ( - bufSize = startBufSize - bufSizeUsesLeft = bufSizeUsageTTL - bufSizeLock sync.Mutex + tcp4BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL) + udp4BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL) + tcp6BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL) + udp6BufSize = newLearningBufSize(startBufSize, maxBufSize, bufSizeUsageTTL) ) -func getBufSize() int { - bufSizeLock.Lock() - defer bufSizeLock.Unlock() +func (lbf *learningBufSize) getBufSize() int { + lbf.Lock() + defer lbf.Unlock() // using bufSize - bufSizeUsesLeft-- + lbf.usesLeft-- // check if we want to reset - if bufSizeUsesLeft <= 0 { + if lbf.usesLeft <= 0 { // decrease - bufSize /= 2 + lbf.size /= 2 // not too little - if bufSize < startBufSize { - bufSize = startBufSize + if lbf.size < lbf.start { + lbf.size = lbf.start } - // reset counter - bufSizeUsesLeft = bufSizeUsageTTL + // reset TTL counter + lbf.usesLeft = lbf.useFor } - return bufSize + return lbf.size } -func increaseBufSize(minSize int) int { - bufSizeLock.Lock() - defer bufSizeLock.Unlock() +func (lbf *learningBufSize) increaseBufSize(minSize int) int { + lbf.Lock() + defer lbf.Unlock() // increase - bufSize *= 2 + lbf.size *= 2 // increase until we reach the minimum size - for bufSize < minSize { - bufSize *= 2 + for lbf.size < minSize { + lbf.size *= 2 } // not too much - if bufSize > maxBufSize { - bufSize = maxBufSize + if lbf.size > lbf.max { + lbf.size = lbf.max } - // reset - bufSizeUsesLeft = bufSizeUsageTTL + // reset TTL counter + lbf.usesLeft = lbf.useFor // return new bufSize - return bufSize + return lbf.size } // getTable returns the current connection state table of Windows of the given protocol and IP version. @@ -163,42 +184,54 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so } var afClass int + var lbf *learningBufSize switch ipVersion { case IPv4: afClass = windows.AF_INET + if protocol == TCP { + lbf = tcp4BufSize + } else { + lbf = udp4BufSize + } case IPv6: afClass = windows.AF_INET6 + if protocol == TCP { + lbf = tcp6BufSize + } else { + lbf = udp6BufSize + } default: return nil, nil, errors.New("invalid protocol") } // try max 5 times maxTries := 5 - bufSize := getBufSize() + usedBufSize := lbf.getBufSize() var buf []byte for i := 1; i <= maxTries; i++ { - buf = make([]byte, bufSize) + bufSizeParam := usedBufSize + buf = make([]byte, bufSizeParam) var r1 uintptr switch protocol { case TCP: r1, _, err = ipHelper.getExtendedTCPTable.Call( - uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pTcpTable - uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize - 0, // _In_ BOOL bOrder - uintptr(afClass), // _In_ ULONG ulAf - iphelperTCPTableOwnerPIDAll, // _In_ TCP_TABLE_CLASS TableClass - 0, // _In_ ULONG Reserved + uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pTcpTable + uintptr(unsafe.Pointer(&bufSizeParam)), // _Inout_ PDWORD pdwSize + 0, // _In_ BOOL bOrder + uintptr(afClass), // _In_ ULONG ulAf + iphelperTCPTableOwnerPIDAll, // _In_ TCP_TABLE_CLASS TableClass + 0, // _In_ ULONG Reserved ) case UDP: r1, _, err = ipHelper.getExtendedUDPTable.Call( - uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pUdpTable, - uintptr(unsafe.Pointer(&bufSize)), // _Inout_ PDWORD pdwSize, - 0, // _In_ BOOL bOrder, - uintptr(afClass), // _In_ ULONG ulAf, - iphelperUDPTableOwnerPID, // _In_ UDP_TABLE_CLASS TableClass, - 0, // _In_ ULONG Reserved + uintptr(unsafe.Pointer(&buf[0])), // _Out_ PVOID pUdpTable, + uintptr(unsafe.Pointer(&bufSizeParam)), // _Inout_ PDWORD pdwSize, + 0, // _In_ BOOL bOrder, + uintptr(afClass), // _In_ ULONG ulAf, + iphelperUDPTableOwnerPID, // _In_ UDP_TABLE_CLASS TableClass, + 0, // _In_ ULONG Reserved ) } @@ -206,13 +239,13 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so case winErrInsufficientBuffer: if i >= maxTries { return nil, nil, fmt.Errorf( - "insufficient buffer error (tried %d times): %d bytes required - [NT 0x%X] %s", - i, bufSize, r1, err, + "insufficient buffer error (tried %d times): provided %d bytes; required %d bytes - [NT 0x%X] %s", + i, usedBufSize, bufSizeParam, r1, err, ) } - // bufSize was modified by ipHelper.getExtended*Table to hold the + // bufSizeParam was modified by ipHelper.getExtended*Table to hold the // required buffer size. - bufSize = increaseBufSize(bufSize) + usedBufSize = lbf.increaseBufSize(bufSizeParam) case winErrInvalidParameter: return nil, nil, fmt.Errorf("invalid parameter: [NT 0x%X] %s", r1, err) case windows.NO_ERROR: