Add rust kext to the mono repo

This commit is contained in:
Vladimir Stoilov
2024-04-29 17:04:08 +03:00
parent 740ef1ad32
commit b0f664047b
98 changed files with 13811 additions and 84 deletions

View File

@@ -0,0 +1,537 @@
use crate::connection::{Connection, ConnectionV4, ConnectionV6, Direction, Verdict};
use crate::connection_map::Key;
use crate::device::{Device, Packet};
use crate::info;
use smoltcp::wire::{
IpAddress, IpProtocol, Ipv4Address, Ipv6Address, IPV4_HEADER_LEN, IPV6_HEADER_LEN,
};
use wdk::filter_engine::callout_data::CalloutData;
use wdk::filter_engine::layer::{
self, FieldsAleAuthConnectV4, FieldsAleAuthConnectV6, FieldsAleAuthRecvAcceptV4,
FieldsAleAuthRecvAcceptV6, ValueType,
};
use wdk::filter_engine::net_buffer::NetBufferList;
use wdk::filter_engine::packet::{Injector, TransportPacketList};
// ALE Layers
#[derive(Debug)]
#[allow(dead_code)]
struct AleLayerData {
is_ipv6: bool,
reauthorize: bool,
process_id: u64,
protocol: IpProtocol,
direction: Direction,
local_ip: IpAddress,
local_port: u16,
remote_ip: IpAddress,
remote_port: u16,
interface_index: u32,
sub_interface_index: u32,
}
impl AleLayerData {
fn as_key(&self) -> Key {
let mut local_port = 0;
let mut remote_port = 0;
match self.protocol {
IpProtocol::Tcp | IpProtocol::Udp => {
local_port = self.local_port;
remote_port = self.remote_port;
}
_ => {}
}
Key {
protocol: self.protocol,
local_address: self.local_ip,
local_port,
remote_address: self.remote_ip,
remote_port,
}
}
}
fn get_protocol(data: &CalloutData, index: usize) -> IpProtocol {
IpProtocol::from(data.get_value_u8(index))
}
fn get_ipv4_address(data: &CalloutData, index: usize) -> IpAddress {
IpAddress::Ipv4(Ipv4Address::from_bytes(
&data.get_value_u32(index).to_be_bytes(),
))
}
fn get_ipv6_address(data: &CalloutData, index: usize) -> IpAddress {
IpAddress::Ipv6(Ipv6Address::from_bytes(data.get_value_byte_array16(index)))
}
pub fn ale_layer_connect_v4(data: CalloutData) {
type Fields = FieldsAleAuthConnectV4;
let ale_data = AleLayerData {
is_ipv6: false,
reauthorize: data.is_reauthorize(Fields::Flags as usize),
process_id: data.get_process_id().unwrap_or(0),
protocol: get_protocol(&data, Fields::IpProtocol as usize),
direction: Direction::Outbound,
local_ip: get_ipv4_address(&data, Fields::IpLocalAddress as usize),
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
remote_ip: get_ipv4_address(&data, Fields::IpRemoteAddress as usize),
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
interface_index: 0,
sub_interface_index: 0,
};
ale_layer_auth(data, ale_data);
}
pub fn ale_layer_accept_v4(data: CalloutData) {
type Fields = FieldsAleAuthRecvAcceptV4;
let ale_data = AleLayerData {
is_ipv6: false,
reauthorize: data.is_reauthorize(Fields::Flags as usize),
process_id: data.get_process_id().unwrap_or(0),
protocol: get_protocol(&data, Fields::IpProtocol as usize),
direction: Direction::Inbound,
local_ip: get_ipv4_address(&data, Fields::IpLocalAddress as usize),
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
remote_ip: get_ipv4_address(&data, Fields::IpRemoteAddress as usize),
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
interface_index: data.get_value_u32(Fields::InterfaceIndex as usize),
sub_interface_index: data.get_value_u32(Fields::SubInterfaceIndex as usize),
};
ale_layer_auth(data, ale_data);
}
pub fn ale_layer_connect_v6(data: CalloutData) {
type Fields = FieldsAleAuthConnectV6;
let ale_data = AleLayerData {
is_ipv6: true,
reauthorize: data.is_reauthorize(Fields::Flags as usize),
process_id: data.get_process_id().unwrap_or(0),
protocol: get_protocol(&data, Fields::IpProtocol as usize),
direction: Direction::Outbound,
local_ip: get_ipv6_address(&data, Fields::IpLocalAddress as usize),
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
remote_ip: get_ipv6_address(&data, Fields::IpRemoteAddress as usize),
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
interface_index: data.get_value_u32(Fields::InterfaceIndex as usize),
sub_interface_index: data.get_value_u32(Fields::SubInterfaceIndex as usize),
};
ale_layer_auth(data, ale_data);
}
pub fn ale_layer_accept_v6(data: CalloutData) {
type Fields = FieldsAleAuthRecvAcceptV6;
let ale_data = AleLayerData {
is_ipv6: true,
reauthorize: data.is_reauthorize(Fields::Flags as usize),
process_id: data.get_process_id().unwrap_or(0),
protocol: get_protocol(&data, Fields::IpProtocol as usize),
direction: Direction::Inbound,
local_ip: get_ipv6_address(&data, Fields::IpLocalAddress as usize),
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
remote_ip: get_ipv6_address(&data, Fields::IpRemoteAddress as usize),
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
interface_index: data.get_value_u32(Fields::InterfaceIndex as usize),
sub_interface_index: data.get_value_u32(Fields::SubInterfaceIndex as usize),
};
ale_layer_auth(data, ale_data);
}
fn ale_layer_auth(mut data: CalloutData, ale_data: AleLayerData) {
let Some(device) = crate::entry::get_device() else {
return;
};
match ale_data.protocol {
IpProtocol::Tcp | IpProtocol::Udp => {
// Only TCP and UDP make sense to be supported in the ALE layer.
// Everything else is not associated with a connection and will be handled in the packet layer.
}
_ => {
// Outbound: Will be handled by packet layer next.
// Inbound: Was already handled by the packet layer.
data.action_permit();
return;
}
}
let key = ale_data.as_key();
// Check if connection is already in cache.
let verdict = if ale_data.is_ipv6 {
device
.connection_cache
.read_connection_v6(&key, |conn| -> Option<Verdict> {
// Function is behind spin lock, just copy and return.
Some(conn.verdict)
})
} else {
device
.connection_cache
.read_connection_v4(&ale_data.as_key(), |conn| -> Option<Verdict> {
// Function is behind spin lock, just copy and return.
Some(conn.verdict)
})
};
// Connection already in cache.
if let Some(verdict) = verdict {
crate::dbg!("processing existing connection: {} {}", key, verdict);
match verdict {
// No verdict yet
Verdict::Undecided => {
crate::dbg!("saving packet: {}", key);
// Connection is already pended. Save packet and wait for verdict.
match save_packet(device, &mut data, &ale_data, false) {
Ok(packet) => {
let info = device.packet_cache.push(
(key, packet),
ale_data.process_id,
ale_data.direction,
true,
);
if let Some(info) = info {
let _ = device.event_queue.push(info);
}
}
Err(err) => {
crate::err!("failed to pend packet: {}", err);
}
};
data.block_and_absorb();
}
// There is a verdict
Verdict::PermanentAccept
| Verdict::Accept
| Verdict::RedirectNameServer
| Verdict::RedirectTunnel => {
// Continue to packet layer.
data.action_permit();
}
Verdict::PermanentBlock | Verdict::Undeterminable | Verdict::Failed => {
// Packet layer will not see this connection.
crate::dbg!("permanent block {}", key);
data.action_block();
}
Verdict::PermanentDrop => {
// Packet layer will not see this connection.
crate::dbg!("permanent drop {}", key);
data.block_and_absorb();
}
Verdict::Block => {
if let Direction::Outbound = ale_data.direction {
// Handled by packet layer.
data.action_permit();
} else {
// packet layer will still see the packets.
data.action_block();
}
}
Verdict::Drop => {
if let Direction::Outbound = ale_data.direction {
// Handled by packet layer.
data.action_permit();
} else {
// packet layer will still see the packets.
data.block_and_absorb();
}
}
}
} else {
crate::dbg!("pending connection: {} {}", key, ale_data.direction);
// Only first packet of a connection can be pended: reauthorize == false
let can_pend_connection = !ale_data.reauthorize;
match save_packet(device, &mut data, &ale_data, can_pend_connection) {
Ok(packet) => {
let info = device.packet_cache.push(
(key, packet),
ale_data.process_id,
ale_data.direction,
true,
);
if let Some(info) = info {
let _ = device.event_queue.push(info);
}
}
Err(err) => {
crate::err!("failed to pend packet: {}", err);
}
};
// Connection is not in cache, add it.
crate::dbg!("adding connection: {} PID: {}", key, ale_data.process_id);
if ale_data.is_ipv6 {
let conn =
ConnectionV6::from_key(&key, ale_data.process_id, ale_data.direction).unwrap();
device.connection_cache.add_connection_v6(conn);
} else {
let conn =
ConnectionV4::from_key(&key, ale_data.process_id, ale_data.direction).unwrap();
device.connection_cache.add_connection_v4(conn);
}
// Drop packet. It will be re-injected after Portmaster returns a verdict.
data.block_and_absorb();
}
}
fn save_packet(
device: &Device,
callout_data: &mut CalloutData,
ale_data: &AleLayerData,
pend: bool,
) -> Result<Packet, alloc::string::String> {
let mut packet_list = None;
let mut save_packet_list = true;
match ale_data.protocol {
IpProtocol::Tcp => {
if let Direction::Outbound = ale_data.direction {
// Only time a packet data is missing is during connect state of outbound TCP connection.
// Don't save packet list only if connection is outbound, reauthorize is false and the protocol is TCP.
save_packet_list = ale_data.reauthorize;
}
}
_ => {}
};
if save_packet_list {
packet_list = create_packet_list(device, callout_data, ale_data);
}
if pend && matches!(ale_data.protocol, IpProtocol::Tcp | IpProtocol::Udp) {
match callout_data.pend_operation(packet_list) {
Ok(classify_defer) => Ok(Packet::AleLayer(classify_defer)),
Err(err) => Err(alloc::format!("failed to defer connection: {}", err)),
}
} else {
Ok(Packet::AleLayer(callout_data.pend_filter_rest(packet_list)))
}
}
fn create_packet_list(
device: &Device,
callout_data: &mut CalloutData,
ale_data: &AleLayerData,
) -> Option<TransportPacketList> {
let mut nbl = NetBufferList::new(callout_data.get_layer_data() as _);
let mut inbound = false;
if let Direction::Inbound = ale_data.direction {
if ale_data.is_ipv6 {
nbl.retreat(IPV6_HEADER_LEN as u32, true);
} else {
nbl.retreat(IPV4_HEADER_LEN as u32, true);
}
inbound = true;
}
let address: &[u8] = match &ale_data.remote_ip {
IpAddress::Ipv4(address) => &address.0,
IpAddress::Ipv6(address) => &address.0,
};
if let Ok(clone) = nbl.clone(&device.network_allocator) {
return Some(Injector::from_ale_callout(
ale_data.is_ipv6,
callout_data,
clone,
address,
inbound,
ale_data.interface_index,
ale_data.sub_interface_index,
));
}
return None;
}
pub fn endpoint_closure_v4(data: CalloutData) {
type Fields = layer::FieldsAleEndpointClosureV4;
let Some(device) = crate::entry::get_device() else {
return;
};
let ip_address_type = data.get_value_type(Fields::IpLocalAddress as usize);
if let ValueType::FwpUint32 = ip_address_type {
let key = Key {
protocol: get_protocol(&data, Fields::IpProtocol as usize),
local_address: get_ipv4_address(&data, Fields::IpLocalAddress as usize),
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
remote_address: get_ipv4_address(&data, Fields::IpRemoteAddress as usize),
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
};
let conn = device.connection_cache.end_connection_v4(key);
if let Some(conn) = conn {
let info = protocol::info::connection_end_event_v4_info(
data.get_process_id().unwrap_or(0),
conn.get_direction() as u8,
u8::from(get_protocol(&data, Fields::IpProtocol as usize)),
conn.local_address.0,
conn.remote_address.0,
conn.local_port,
conn.remote_port,
);
let _ = device.event_queue.push(info);
}
} else {
// Invalid ip address type. Just ignore the error.
// err!(
// device.logger,
// "unknown ipv4 address type: {:?}",
// ip_address_type
// );
}
}
pub fn endpoint_closure_v6(data: CalloutData) {
type Fields = layer::FieldsAleEndpointClosureV6;
let Some(device) = crate::entry::get_device() else {
return;
};
let local_ip_address_type = data.get_value_type(Fields::IpLocalAddress as usize);
let remote_ip_address_type = data.get_value_type(Fields::IpRemoteAddress as usize);
if let ValueType::FwpByteArray16Type = local_ip_address_type {
if let ValueType::FwpByteArray16Type = remote_ip_address_type {
let key = Key {
protocol: get_protocol(&data, Fields::IpProtocol as usize),
local_address: get_ipv6_address(&data, Fields::IpLocalAddress as usize),
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
remote_address: get_ipv6_address(&data, Fields::IpRemoteAddress as usize),
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
};
let conn = device.connection_cache.end_connection_v6(key);
if let Some(conn) = conn {
let info = protocol::info::connection_end_event_v6_info(
data.get_process_id().unwrap_or(0),
conn.get_direction() as u8,
u8::from(get_protocol(&data, Fields::IpProtocol as usize)),
conn.local_address.0,
conn.remote_address.0,
conn.local_port,
conn.remote_port,
);
let _ = device.event_queue.push(info);
}
}
}
}
pub fn ale_resource_monitor(data: CalloutData) {
let Some(device) = crate::entry::get_device() else {
return;
};
match data.layer {
layer::Layer::AleResourceAssignmentV4Discard => {
type Fields = layer::FieldsAleResourceAssignmentV4;
if let Some(conns) = device.connection_cache.end_all_on_port_v4((
get_protocol(&data, Fields::IpProtocol as usize),
data.get_value_u16(Fields::IpLocalPort as usize),
)) {
let process_id = data.get_process_id().unwrap_or(0);
info!(
"Port {}/{} Ipv4 assign request discarded pid={}",
data.get_value_u16(Fields::IpLocalPort as usize),
get_protocol(&data, Fields::IpProtocol as usize),
process_id,
);
for conn in conns {
let info = protocol::info::connection_end_event_v4_info(
process_id,
conn.get_direction() as u8,
data.get_value_u8(Fields::IpProtocol as usize),
conn.local_address.0,
conn.remote_address.0,
conn.local_port,
conn.remote_port,
);
let _ = device.event_queue.push(info);
}
}
}
layer::Layer::AleResourceAssignmentV6Discard => {
type Fields = layer::FieldsAleResourceAssignmentV6;
if let Some(conns) = device.connection_cache.end_all_on_port_v6((
get_protocol(&data, Fields::IpProtocol as usize),
data.get_value_u16(Fields::IpLocalPort as usize),
)) {
let process_id = data.get_process_id().unwrap_or(0);
info!(
"Port {}/{} Ipv6 assign request discarded pid={}",
data.get_value_u16(Fields::IpLocalPort as usize),
get_protocol(&data, Fields::IpProtocol as usize),
process_id,
);
for conn in conns {
let info = protocol::info::connection_end_event_v6_info(
process_id,
conn.get_direction() as u8,
data.get_value_u8(Fields::IpProtocol as usize),
conn.local_address.0,
conn.remote_address.0,
conn.local_port,
conn.remote_port,
);
let _ = device.event_queue.push(info);
}
}
}
layer::Layer::AleResourceReleaseV4 => {
type Fields = layer::FieldsAleResourceReleaseV4;
if let Some(conns) = device.connection_cache.end_all_on_port_v4((
get_protocol(&data, Fields::IpProtocol as usize),
data.get_value_u16(Fields::IpLocalPort as usize),
)) {
let process_id = data.get_process_id().unwrap_or(0);
info!(
"Port {}/{} released pid={}",
data.get_value_u16(Fields::IpLocalPort as usize),
get_protocol(&data, Fields::IpProtocol as usize),
process_id,
);
for conn in conns {
let info = protocol::info::connection_end_event_v4_info(
process_id,
conn.get_direction() as u8,
data.get_value_u8(Fields::IpProtocol as usize),
conn.local_address.0,
conn.remote_address.0,
conn.local_port,
conn.remote_port,
);
let _ = device.event_queue.push(info);
}
}
}
layer::Layer::AleResourceReleaseV6 => {
type Fields = layer::FieldsAleResourceReleaseV6;
if let Some(conns) = device.connection_cache.end_all_on_port_v6((
get_protocol(&data, Fields::IpProtocol as usize),
data.get_value_u16(Fields::IpLocalPort as usize),
)) {
let process_id = data.get_process_id().unwrap_or(0);
info!(
"Port {}/{} released pid={}",
data.get_value_u16(Fields::IpLocalPort as usize),
get_protocol(&data, Fields::IpProtocol as usize),
process_id,
);
for conn in conns {
let info = protocol::info::connection_end_event_v6_info(
process_id,
conn.get_direction() as u8,
data.get_value_u8(Fields::IpProtocol as usize),
conn.local_address.0,
conn.remote_address.0,
conn.local_port,
conn.remote_port,
);
let _ = device.event_queue.push(info);
}
}
}
_ => {}
}
}

View File

@@ -0,0 +1,25 @@
use core::cell::RefCell;
use alloc::vec::Vec;
pub struct ArrayHolder(RefCell<Option<Vec<u8>>>);
unsafe impl Sync for ArrayHolder {}
impl ArrayHolder {
pub const fn default() -> Self {
Self(RefCell::new(None))
}
pub fn save(&self, data: &[u8]) {
if let Ok(mut opt) = self.0.try_borrow_mut() {
opt.replace(data.to_vec());
}
}
pub fn load(&self) -> Option<Vec<u8>> {
if let Ok(mut opt) = self.0.try_borrow_mut() {
return opt.take();
}
None
}
}

View File

@@ -0,0 +1,293 @@
use protocol::info::{BandwidthValueV4, BandwidthValueV6, Info};
use smoltcp::wire::{IpProtocol, Ipv4Address, Ipv6Address};
use wdk::rw_spin_lock::RwSpinLock;
use crate::driver_hashmap::DeviceHashMap;
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
pub struct Key<Address>
where
Address: Eq + PartialEq,
{
pub local_ip: Address,
pub local_port: u16,
pub remote_ip: Address,
pub remote_port: u16,
}
struct Value {
received_bytes: usize,
transmitted_bytes: usize,
}
enum Direction {
Tx(usize),
Rx(usize),
}
pub struct Bandwidth {
stats_tcp_v4: DeviceHashMap<Key<Ipv4Address>, Value>,
stats_tcp_v4_lock: RwSpinLock,
stats_tcp_v6: DeviceHashMap<Key<Ipv6Address>, Value>,
stats_tcp_v6_lock: RwSpinLock,
stats_udp_v4: DeviceHashMap<Key<Ipv4Address>, Value>,
stats_udp_v4_lock: RwSpinLock,
stats_udp_v6: DeviceHashMap<Key<Ipv6Address>, Value>,
stats_udp_v6_lock: RwSpinLock,
}
impl Bandwidth {
pub fn new() -> Self {
Self {
stats_tcp_v4: DeviceHashMap::new(),
stats_tcp_v4_lock: RwSpinLock::default(),
stats_tcp_v6: DeviceHashMap::new(),
stats_tcp_v6_lock: RwSpinLock::default(),
stats_udp_v4: DeviceHashMap::new(),
stats_udp_v4_lock: RwSpinLock::default(),
stats_udp_v6: DeviceHashMap::new(),
stats_udp_v6_lock: RwSpinLock::default(),
}
}
pub fn get_all_updates_tcp_v4(&mut self) -> Option<Info> {
let stats_map;
{
let _guard = self.stats_tcp_v4_lock.write_lock();
if self.stats_tcp_v4.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_tcp_v4, DeviceHashMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
for (key, value) in stats_map.iter() {
values.push(BandwidthValueV4 {
local_ip: key.local_ip.0,
local_port: key.local_port,
remote_ip: key.remote_ip.0,
remote_port: key.remote_port,
transmitted_bytes: value.transmitted_bytes as u64,
received_bytes: value.received_bytes as u64,
});
}
Some(protocol::info::bandiwth_stats_array_v4(
u8::from(IpProtocol::Tcp),
values,
))
}
pub fn get_all_updates_tcp_v6(&mut self) -> Option<Info> {
let stats_map;
{
let _guard = self.stats_tcp_v6_lock.write_lock();
if self.stats_tcp_v6.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
for (key, value) in stats_map.iter() {
values.push(BandwidthValueV6 {
local_ip: key.local_ip.0,
local_port: key.local_port,
remote_ip: key.remote_ip.0,
remote_port: key.remote_port,
transmitted_bytes: value.transmitted_bytes as u64,
received_bytes: value.received_bytes as u64,
});
}
Some(protocol::info::bandiwth_stats_array_v6(
u8::from(IpProtocol::Tcp),
values,
))
}
pub fn get_all_updates_udp_v4(&mut self) -> Option<Info> {
let stats_map;
{
let _guard = self.stats_udp_v4_lock.write_lock();
if self.stats_udp_v4.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_udp_v4, DeviceHashMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
for (key, value) in stats_map.iter() {
values.push(BandwidthValueV4 {
local_ip: key.local_ip.0,
local_port: key.local_port,
remote_ip: key.remote_ip.0,
remote_port: key.remote_port,
transmitted_bytes: value.transmitted_bytes as u64,
received_bytes: value.received_bytes as u64,
});
}
Some(protocol::info::bandiwth_stats_array_v4(
u8::from(IpProtocol::Udp),
values,
))
}
pub fn get_all_updates_udp_v6(&mut self) -> Option<Info> {
let stats_map;
{
let _guard = self.stats_udp_v6_lock.write_lock();
if self.stats_tcp_v6.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
for (key, value) in stats_map.iter() {
values.push(BandwidthValueV6 {
local_ip: key.local_ip.0,
local_port: key.local_port,
remote_ip: key.remote_ip.0,
remote_port: key.remote_port,
transmitted_bytes: value.transmitted_bytes as u64,
received_bytes: value.received_bytes as u64,
});
}
Some(protocol::info::bandiwth_stats_array_v6(
u8::from(IpProtocol::Udp),
values,
))
}
pub fn update_tcp_v4_tx(&mut self, key: Key<Ipv4Address>, tx_bytes: usize) {
Self::update(
&mut self.stats_tcp_v4,
&mut self.stats_tcp_v4_lock,
key,
Direction::Tx(tx_bytes),
);
}
pub fn update_tcp_v4_rx(&mut self, key: Key<Ipv4Address>, rx_bytes: usize) {
Self::update(
&mut self.stats_tcp_v4,
&mut self.stats_tcp_v4_lock,
key,
Direction::Rx(rx_bytes),
);
}
pub fn update_tcp_v6_tx(&mut self, key: Key<Ipv6Address>, tx_bytes: usize) {
Self::update(
&mut self.stats_tcp_v6,
&mut self.stats_tcp_v6_lock,
key,
Direction::Tx(tx_bytes),
);
}
pub fn update_tcp_v6_rx(&mut self, key: Key<Ipv6Address>, rx_bytes: usize) {
Self::update(
&mut self.stats_tcp_v6,
&mut self.stats_tcp_v6_lock,
key,
Direction::Rx(rx_bytes),
);
}
pub fn update_udp_v4_tx(&mut self, key: Key<Ipv4Address>, tx_bytes: usize) {
Self::update(
&mut self.stats_udp_v4,
&mut self.stats_udp_v4_lock,
key,
Direction::Tx(tx_bytes),
);
}
pub fn update_udp_v4_rx(&mut self, key: Key<Ipv4Address>, rx_bytes: usize) {
Self::update(
&mut self.stats_udp_v4,
&mut self.stats_udp_v4_lock,
key,
Direction::Rx(rx_bytes),
);
}
pub fn update_udp_v6_tx(&mut self, key: Key<Ipv6Address>, tx_bytes: usize) {
Self::update(
&mut self.stats_udp_v6,
&mut self.stats_udp_v6_lock,
key,
Direction::Tx(tx_bytes),
);
}
pub fn update_udp_v6_rx(&mut self, key: Key<Ipv6Address>, rx_bytes: usize) {
Self::update(
&mut self.stats_udp_v6,
&mut self.stats_udp_v6_lock,
key,
Direction::Rx(rx_bytes),
);
}
fn update<Address: Eq + PartialEq + core::hash::Hash>(
map: &mut DeviceHashMap<Key<Address>, Value>,
lock: &mut RwSpinLock,
key: Key<Address>,
bytes: Direction,
) {
let _guard = lock.write_lock();
if let Some(value) = map.get_mut(&key) {
match bytes {
Direction::Tx(bytes_count) => value.transmitted_bytes += bytes_count,
Direction::Rx(bytes_count) => value.received_bytes += bytes_count,
}
} else {
let mut received_bytes = 0;
let mut transmitted_bytes = 0;
match bytes {
Direction::Tx(bytes_count) => transmitted_bytes += bytes_count,
Direction::Rx(bytes_count) => received_bytes += bytes_count,
}
map.insert(
key,
Value {
received_bytes,
transmitted_bytes,
},
);
}
}
#[allow(dead_code)]
pub fn get_entries_count(&self) -> usize {
let mut size = 0;
{
let values = &self.stats_tcp_v4.values();
let _guard = self.stats_tcp_v4_lock.read_lock();
size += values.len();
}
{
let values = &self.stats_tcp_v6.values();
let _guard = self.stats_tcp_v6_lock.read_lock();
size += values.len();
}
{
let values = &self.stats_udp_v4.values();
let _guard = self.stats_udp_v4_lock.read_lock();
size += values.len();
}
{
let values = &self.stats_udp_v6.values();
let _guard = self.stats_udp_v6_lock.read_lock();
size += values.len();
}
return size;
}
}

View File

@@ -0,0 +1,185 @@
use alloc::vec::Vec;
use wdk::filter_engine::callout::FilterType;
use wdk::{
consts,
filter_engine::{callout::Callout, layer::Layer},
};
use crate::{ale_callouts, packet_callouts, stream_callouts};
pub fn get_callout_vec() -> Vec<Callout> {
alloc::vec![
// -----------------------------------------
// ALE Auth layers
Callout::new(
"AleLayerOutboundV4",
"ALE layer for outbound connection for ipv4",
0x58545073_f893_454c_bbea_a57bc964f46d,
Layer::AleAuthConnectV4,
consts::FWP_ACTION_CALLOUT_TERMINATING,
FilterType::Resettable,
ale_callouts::ale_layer_connect_v4,
),
Callout::new(
"AleLayerInboundV4",
"ALE layer for inbound connections for ipv4",
0xc6021395_0724_4e2c_ae20_3dde51fc3c68,
Layer::AleAuthRecvAcceptV4,
consts::FWP_ACTION_CALLOUT_TERMINATING,
FilterType::Resettable,
ale_callouts::ale_layer_accept_v4,
),
Callout::new(
"AleLayerOutboundV6",
"ALE layer for outbound connections for ipv6",
0x4bd2a080_2585_478d_977c_7f340c6bc3d4,
Layer::AleAuthConnectV6,
consts::FWP_ACTION_CALLOUT_TERMINATING,
FilterType::Resettable,
ale_callouts::ale_layer_connect_v6,
),
Callout::new(
"AleLayerInboundV6",
"ALE layer for inbound connections for ipv6",
0xd24480da_38fa_4099_9383_b5c83b69e4f2,
Layer::AleAuthRecvAcceptV6,
consts::FWP_ACTION_CALLOUT_TERMINATING,
FilterType::Resettable,
ale_callouts::ale_layer_accept_v6,
),
// -----------------------------------------
// ALE connection end layers
Callout::new(
"AleEndpointClosureV4",
"ALE layer for indicating closing of connection for ipv4",
0x58f02845_ace9_4455_ac80_8a84b86fe566,
Layer::AleEndpointClosureV4,
consts::FWP_ACTION_CALLOUT_INSPECTION,
FilterType::NonResettable,
ale_callouts::endpoint_closure_v4,
),
Callout::new(
"AleEndpointClosureV6",
"ALE layer for indicating closing of connection for ipv6",
0x2bc82359_9dc5_4315_9c93_c89467e283ce,
Layer::AleEndpointClosureV6,
consts::FWP_ACTION_CALLOUT_INSPECTION,
FilterType::NonResettable,
ale_callouts::endpoint_closure_v6,
),
// -----------------------------------------
// ALE resource assignment and release.
// Callout::new(
// "AleResourceAssignmentV4",
// "Ipv4 Port assignment monitoring",
// 0x6b9d1985_6f75_4d05_b9b5_1607e187906f,
// Layer::AleResourceAssignmentV4Discard,
// consts::FWP_ACTION_CALLOUT_INSPECTION,
// FilterType::NonResettable,
// ale_callouts::ale_resource_monitor,
// ),
Callout::new(
"AleResourceReleaseV4",
"Ipv4 Port release monitor",
0x7b513bb3_a0be_4f77_a4bc_03c052abe8d7,
Layer::AleResourceReleaseV4,
consts::FWP_ACTION_CALLOUT_INSPECTION,
FilterType::NonResettable,
ale_callouts::ale_resource_monitor,
),
// Callout::new(
// "AleResourceAssignmentV6",
// "Ipv4 Port assignment monitor",
// 0xb0d02299_3d3e_437d_916a_f0e96a60cc18,
// Layer::AleResourceAssignmentV6Discard,
// consts::FWP_ACTION_CALLOUT_INSPECTION,
// FilterType::NonResettable,
// ale_callouts::ale_resource_monitor,
// ),
Callout::new(
"AleResourceReleaseV6",
"Ipv6 Port release monitor",
0x6cf36e04_e656_42c3_8cac_a1ce05328bd1,
Layer::AleResourceReleaseV6,
consts::FWP_ACTION_CALLOUT_INSPECTION,
FilterType::NonResettable,
ale_callouts::ale_resource_monitor,
),
// -----------------------------------------
// Stream layer
Callout::new(
"StreamLayerV4",
"Stream layer for ipv4",
0xe2ca13bf_9710_4caa_a45c_e8c78b5ac780,
Layer::StreamV4,
consts::FWP_ACTION_CALLOUT_INSPECTION,
FilterType::NonResettable,
stream_callouts::stream_layer_tcp_v4,
),
Callout::new(
"StreamLayerV6",
"Stream layer for ipv6",
0x66c549b3_11e2_4b27_8f73_856e6fd82baa,
Layer::StreamV6,
consts::FWP_ACTION_CALLOUT_INSPECTION,
FilterType::NonResettable,
stream_callouts::stream_layer_tcp_v6,
),
Callout::new(
"DatagramDataLayerV4",
"DatagramData layer for ipv4",
0xe7eeeaba_168a_45bb_8747_e1a702feb2c5,
Layer::DatagramDataV4,
consts::FWP_ACTION_CALLOUT_INSPECTION,
FilterType::NonResettable,
stream_callouts::stream_layer_udp_v4,
),
Callout::new(
"DatagramDataLayerV6",
"DatagramData layer for ipv4",
0xb25862cd_f744_4452_b14a_d0c1e5a25b30,
Layer::DatagramDataV6,
consts::FWP_ACTION_CALLOUT_INSPECTION,
FilterType::NonResettable,
stream_callouts::stream_layer_udp_v6,
),
// -----------------------------------------
// Packet layers
Callout::new(
"IPPacketOutboundV4",
"IP packet outbound network layer callout for Ipv4",
0xf3183afe_dc35_49f1_8ea2_b16b5666dd36,
Layer::OutboundIppacketV4,
consts::FWP_ACTION_CALLOUT_TERMINATING,
FilterType::NonResettable,
packet_callouts::ip_packet_layer_outbound_v4,
),
Callout::new(
"IPPacketInboundV4",
"IP packet inbound network layer callout for Ipv4",
0xf0369374_203d_4bf0_83d2_b2ad3cc17a50,
Layer::InboundIppacketV4,
consts::FWP_ACTION_CALLOUT_TERMINATING,
FilterType::NonResettable,
packet_callouts::ip_packet_layer_inbound_v4,
),
Callout::new(
"IPPacketOutboundV6",
"IP packet outbound network layer callout for Ipv6",
0x91daf8bc_0908_4bf8_9f81_2c538ab8f25a,
Layer::OutboundIppacketV6,
consts::FWP_ACTION_CALLOUT_TERMINATING,
FilterType::NonResettable,
packet_callouts::ip_packet_layer_outbound_v6,
),
Callout::new(
"IPPacketInboundV6",
"IP packet inbound network layer callout for Ipv6",
0xfe9faf5f_ceb2_4cd9_9995_f2f2b4f5fcc0,
Layer::InboundIppacketV6,
consts::FWP_ACTION_CALLOUT_TERMINATING,
FilterType::NonResettable,
packet_callouts::ip_packet_layer_inbound_v6,
)
]
}

View File

@@ -0,0 +1,59 @@
#![allow(dead_code)]
use core::fmt::Display;
use num_derive::{FromPrimitive, ToPrimitive};
pub const ICMPV4_CODE_DESTINATION_UNREACHABLE: u32 = 3;
pub const ICMPV4_CODE_DU_PORT_UNREACHABLE: u32 = 3; // Destination Unreachable (Port unreachable) ;
pub const ICMPV4_CODE_DU_ADMINISTRATIVELY_PROHIBITED: u32 = 13; // Destination Unreachable (Communication Administratively Prohibited) ;
pub const ICMPV6_CODE_DESTINATION_UNREACHABLE: u32 = 1;
pub const ICMPV6_CODE_DU_PORT_UNREACHABLE: u32 = 4; // Destination Unreachable (Port unreachable) ;
enum Direction {
Outbound = 0,
Inbound = 1,
}
const SIOCTL_TYPE: u32 = 40000;
macro_rules! ctl_code {
($device_type:expr, $function:expr, $method:expr, $access:expr) => {
($device_type << 16) | ($access << 14) | ($function << 2) | $method
};
}
pub const METHOD_BUFFERED: u32 = 0;
pub const METHOD_IN_DIRECT: u32 = 1;
pub const METHOD_OUT_DIRECT: u32 = 2;
pub const METHOD_NEITHER: u32 = 3;
pub const FILE_READ_DATA: u32 = 0x0001; // file & pipe
pub const FILE_WRITE_DATA: u32 = 0x0002; // file & pipe
#[repr(u32)]
#[derive(FromPrimitive, ToPrimitive)]
pub enum ControlCode {
Version = ctl_code!(
SIOCTL_TYPE,
0x800,
METHOD_BUFFERED,
FILE_READ_DATA | FILE_WRITE_DATA
),
ShutdownRequest = ctl_code!(
SIOCTL_TYPE,
0x801,
METHOD_BUFFERED,
FILE_READ_DATA | FILE_WRITE_DATA
),
}
impl Display for ControlCode {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
ControlCode::Version => _ = write!(f, "Version"),
ControlCode::ShutdownRequest => _ = write!(f, "Shutdown"),
};
return Ok(());
}
}

View File

@@ -0,0 +1,499 @@
use alloc::{
boxed::Box,
string::{String, ToString},
};
use core::{
fmt::{Debug, Display},
sync::atomic::{AtomicU64, Ordering},
};
use num_derive::FromPrimitive;
use smoltcp::wire::{IpAddress, IpProtocol, Ipv4Address, Ipv6Address};
use crate::connection_map::Key;
pub static PM_DNS_PORT: u16 = 53;
pub static PM_SPN_PORT: u16 = 717;
// Make sure this in sync with the Go version
#[derive(Copy, Clone, FromPrimitive)]
#[repr(u8)]
#[rustfmt::skip]
pub enum Verdict {
Undecided = 0, // Undecided is the default status of new connections.
Undeterminable = 1,
Accept = 2,
PermanentAccept = 3,
Block = 4,
PermanentBlock = 5,
Drop = 6,
PermanentDrop = 7,
RedirectNameServer = 8,
RedirectTunnel = 9,
Failed = 10,
}
impl Display for Verdict {
#[rustfmt::skip]
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Verdict::Undecided => write!(f, "Undecided"),
Verdict::Undeterminable => write!(f, "Undeterminable"),
Verdict::Accept => write!(f, "Accept"),
Verdict::PermanentAccept => write!(f, "PermanentAccept"),
Verdict::Block => write!(f, "Block"),
Verdict::PermanentBlock => write!(f, "PermanentBlock"),
Verdict::Drop => write!(f, "Drop"),
Verdict::PermanentDrop => write!(f, "PermanentDrop"),
Verdict::RedirectNameServer => write!(f, "RedirectNameServer"),
Verdict::RedirectTunnel => write!(f, "RedirectTunnel"),
Verdict::Failed => write!(f, "Failed"),
}
}
}
#[allow(dead_code)]
impl Verdict {
/// Returns true if the verdict is a redirect.
pub fn is_redirect(&self) -> bool {
matches!(self, Verdict::RedirectNameServer | Verdict::RedirectTunnel)
}
/// Returns true if the verdict is a permanent verdict.
pub fn is_permanent(&self) -> bool {
matches!(
self,
Verdict::PermanentAccept
| Verdict::PermanentBlock
| Verdict::PermanentDrop
| Verdict::RedirectNameServer
| Verdict::RedirectTunnel
)
}
}
/// Direction of the connection.
#[derive(Copy, Clone, FromPrimitive)]
#[repr(u8)]
pub enum Direction {
Outbound = 0,
Inbound = 1,
}
impl Display for Direction {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Direction::Outbound => write!(f, "Outbound"),
Direction::Inbound => write!(f, "Inbound"),
}
}
}
impl Debug for Direction {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{}", self)
}
}
#[derive(Clone)]
pub struct ConnectionExtra {
pub(crate) end_timestamp: u64,
pub(crate) direction: Direction,
}
pub trait Connection {
fn redirect_info(&self) -> Option<RedirectInfo> {
let redirect_address = if self.is_ipv6() {
IpAddress::Ipv6(Ipv6Address::LOOPBACK)
} else {
IpAddress::Ipv4(Ipv4Address::new(127, 0, 0, 1))
};
match self.get_verdict() {
Verdict::RedirectNameServer => Some(RedirectInfo {
local_address: self.get_local_address(),
remote_address: self.get_remote_address(),
remote_port: self.get_remote_port(),
redirect_port: PM_DNS_PORT,
unify: false,
redirect_address,
}),
Verdict::RedirectTunnel => Some(RedirectInfo {
local_address: self.get_local_address(),
remote_address: self.get_remote_address(),
remote_port: self.get_remote_port(),
redirect_port: PM_SPN_PORT,
unify: true,
redirect_address,
}),
_ => None,
}
}
/// Returns the key of the connection.
fn get_key(&self) -> Key {
Key {
protocol: self.get_protocol(),
local_address: self.get_local_address(),
local_port: self.get_local_port(),
remote_address: self.get_remote_address(),
remote_port: self.get_remote_port(),
}
}
/// Returns true if the connection is equal to the given key. The key is considered equal if the remote port and address are equal.
fn remote_equals(&self, key: &Key) -> bool;
/// Returns true if the connection is equal to the given key for redirecting. The key is considered equal if the remote port and address are equal.
fn redirect_equals(&self, key: &Key) -> bool;
/// Returns the protocol of the connection.
fn get_protocol(&self) -> IpProtocol;
/// Returns the verdict of the connection.
fn get_verdict(&self) -> Verdict;
/// Returns the local address of the connection.
fn get_local_address(&self) -> IpAddress;
/// Returns the local port of the connection.
fn get_local_port(&self) -> u16;
/// Returns the remote address of the connection.
fn get_remote_address(&self) -> IpAddress;
/// Returns the remote port of the connection.
fn get_remote_port(&self) -> u16;
/// Returns true if the connection is an IPv6 connection.
fn is_ipv6(&self) -> bool;
/// Returns the direction of the connection.
fn get_direction(&self) -> Direction;
// Returns the process id of the connection.
fn get_process_id(&self) -> u64;
/// Ends the connection.
fn end(&mut self, timestamp: u64);
/// Returns true if the connection has ended.
fn has_ended(&self) -> bool {
self.get_end_time() > 0
}
/// Returns the timestamp when the connection ended.
fn get_end_time(&self) -> u64;
/// Returns the timestamp when the connection was last accessed.
fn get_last_accessed_time(&self) -> u64;
/// Sets the timestamp when the connection was last accessed.
fn set_last_accessed_time(&self, timestamp: u64);
}
pub struct ConnectionV4 {
pub(crate) protocol: IpProtocol,
pub(crate) local_address: Ipv4Address,
pub(crate) local_port: u16,
pub(crate) remote_address: Ipv4Address,
pub(crate) remote_port: u16,
pub(crate) verdict: Verdict,
pub(crate) process_id: u64,
pub(crate) last_accessed_timestamp: AtomicU64,
pub(crate) extra: Box<ConnectionExtra>,
}
pub struct ConnectionV6 {
pub(crate) protocol: IpProtocol,
pub(crate) local_address: Ipv6Address,
pub(crate) local_port: u16,
pub(crate) remote_address: Ipv6Address,
pub(crate) remote_port: u16,
pub(crate) verdict: Verdict,
pub(crate) process_id: u64,
pub(crate) last_accessed_timestamp: AtomicU64,
pub(crate) extra: Box<ConnectionExtra>,
}
#[derive(Debug)]
pub struct RedirectInfo {
pub(crate) local_address: IpAddress,
pub(crate) remote_address: IpAddress,
pub(crate) remote_port: u16,
pub(crate) redirect_port: u16,
pub(crate) unify: bool,
pub(crate) redirect_address: IpAddress,
}
impl ConnectionV4 {
/// Creates a new ipv4 connection from the given key.
pub fn from_key(key: &Key, process_id: u64, direction: Direction) -> Result<Self, String> {
let IpAddress::Ipv4(local_address) = key.local_address else {
return Err("wrong ip address version".to_string());
};
let IpAddress::Ipv4(remote_address) = key.remote_address else {
return Err("wrong ip address version".to_string());
};
let timestamp = wdk::utils::get_system_timestamp_ms();
Ok(Self {
protocol: key.protocol,
local_address,
local_port: key.local_port,
remote_address,
remote_port: key.remote_port,
verdict: Verdict::Undecided,
process_id,
last_accessed_timestamp: AtomicU64::new(timestamp),
extra: Box::new(ConnectionExtra {
direction,
end_timestamp: 0,
}),
})
}
}
impl Connection for ConnectionV4 {
fn remote_equals(&self, key: &Key) -> bool {
if self.remote_port != key.remote_port {
return false;
}
if let IpAddress::Ipv4(remote_address) = &key.remote_address {
return self.remote_address.eq(remote_address);
}
false
}
fn get_key(&self) -> Key {
Key {
protocol: self.protocol,
local_address: IpAddress::Ipv4(self.local_address),
local_port: self.local_port,
remote_address: IpAddress::Ipv4(self.remote_address),
remote_port: self.remote_port,
}
}
fn redirect_equals(&self, key: &Key) -> bool {
match self.verdict {
Verdict::RedirectNameServer => {
if key.remote_port != PM_DNS_PORT {
return false;
}
match key.remote_address {
IpAddress::Ipv4(a) => a.is_loopback(),
IpAddress::Ipv6(_) => false,
}
}
Verdict::RedirectTunnel => {
if key.remote_port != PM_SPN_PORT {
return false;
}
key.local_address.eq(&key.remote_address)
}
_ => false,
}
}
fn get_protocol(&self) -> IpProtocol {
self.protocol
}
fn get_verdict(&self) -> Verdict {
self.verdict
}
fn get_local_address(&self) -> IpAddress {
IpAddress::Ipv4(self.local_address)
}
fn get_local_port(&self) -> u16 {
self.local_port
}
fn get_remote_address(&self) -> IpAddress {
IpAddress::Ipv4(self.remote_address)
}
fn get_remote_port(&self) -> u16 {
self.remote_port
}
fn is_ipv6(&self) -> bool {
false
}
fn get_process_id(&self) -> u64 {
self.process_id
}
fn get_direction(&self) -> Direction {
self.extra.direction
}
fn end(&mut self, timestamp: u64) {
self.extra.end_timestamp = timestamp;
}
fn get_end_time(&self) -> u64 {
self.extra.end_timestamp
}
fn get_last_accessed_time(&self) -> u64 {
self.last_accessed_timestamp.load(Ordering::Relaxed)
}
fn set_last_accessed_time(&self, timestamp: u64) {
self.last_accessed_timestamp
.store(timestamp, Ordering::Relaxed);
}
}
impl Clone for ConnectionV4 {
fn clone(&self) -> Self {
Self {
protocol: self.protocol,
local_address: self.local_address,
local_port: self.local_port,
remote_address: self.remote_address,
remote_port: self.remote_port,
verdict: self.verdict,
process_id: self.process_id,
last_accessed_timestamp: AtomicU64::new(
self.last_accessed_timestamp.load(Ordering::Relaxed),
),
extra: self.extra.clone(),
}
}
}
impl ConnectionV6 {
/// Creates a new ipv6 connection from the given key.
pub fn from_key(key: &Key, process_id: u64, direction: Direction) -> Result<Self, String> {
let IpAddress::Ipv6(local_address) = key.local_address else {
return Err("wrong ip address version".to_string());
};
let IpAddress::Ipv6(remote_address) = key.remote_address else {
return Err("wrong ip address version".to_string());
};
let timestamp = wdk::utils::get_system_timestamp_ms();
Ok(Self {
protocol: key.protocol,
local_address,
local_port: key.local_port,
remote_address,
remote_port: key.remote_port,
verdict: Verdict::Undecided,
process_id,
last_accessed_timestamp: AtomicU64::new(timestamp),
extra: Box::new(ConnectionExtra {
direction,
end_timestamp: 0,
}),
})
}
}
impl Connection for ConnectionV6 {
fn remote_equals(&self, key: &Key) -> bool {
if self.remote_port != key.remote_port {
return false;
}
if let IpAddress::Ipv6(remote_address) = &key.remote_address {
return self.remote_address.eq(remote_address);
}
false
}
fn get_key(&self) -> Key {
Key {
protocol: self.protocol,
local_address: IpAddress::Ipv6(self.local_address),
local_port: self.local_port,
remote_address: IpAddress::Ipv6(self.remote_address),
remote_port: self.remote_port,
}
}
fn redirect_equals(&self, key: &Key) -> bool {
match self.verdict {
Verdict::RedirectNameServer => {
if key.remote_port != PM_DNS_PORT {
return false;
}
match key.remote_address {
IpAddress::Ipv4(_) => false,
IpAddress::Ipv6(a) => a.is_loopback(),
}
}
Verdict::RedirectTunnel => {
if key.remote_port != PM_SPN_PORT {
return false;
}
key.local_address.eq(&key.remote_address)
}
_ => false,
}
}
fn get_protocol(&self) -> IpProtocol {
self.protocol
}
fn get_verdict(&self) -> Verdict {
self.verdict
}
fn get_local_address(&self) -> IpAddress {
IpAddress::Ipv6(self.local_address)
}
fn get_local_port(&self) -> u16 {
self.local_port
}
fn get_remote_address(&self) -> IpAddress {
IpAddress::Ipv6(self.remote_address)
}
fn get_remote_port(&self) -> u16 {
self.remote_port
}
fn is_ipv6(&self) -> bool {
true
}
fn get_process_id(&self) -> u64 {
self.process_id
}
fn get_direction(&self) -> Direction {
self.extra.direction
}
fn end(&mut self, timestamp: u64) {
self.extra.end_timestamp = timestamp;
}
fn get_end_time(&self) -> u64 {
self.extra.end_timestamp
}
fn get_last_accessed_time(&self) -> u64 {
self.last_accessed_timestamp.load(Ordering::Relaxed)
}
fn set_last_accessed_time(&self, timestamp: u64) {
self.last_accessed_timestamp
.store(timestamp, Ordering::Relaxed);
}
}
impl Clone for ConnectionV6 {
fn clone(&self) -> Self {
Self {
protocol: self.protocol,
local_address: self.local_address,
local_port: self.local_port,
remote_address: self.remote_address,
remote_port: self.remote_port,
verdict: self.verdict,
process_id: self.process_id,
last_accessed_timestamp: AtomicU64::new(
self.last_accessed_timestamp.load(Ordering::Relaxed),
),
extra: self.extra.clone(),
}
}
}

View File

@@ -0,0 +1,200 @@
use core::time::Duration;
use crate::{
connection::{Connection, ConnectionV4, ConnectionV6, RedirectInfo, Verdict},
connection_map::{ConnectionMap, Key},
};
use alloc::{format, string::String, vec::Vec};
use smoltcp::wire::IpProtocol;
use wdk::rw_spin_lock::RwSpinLock;
pub struct ConnectionCache {
connections_v4: ConnectionMap<ConnectionV4>,
connections_v6: ConnectionMap<ConnectionV6>,
lock_v4: RwSpinLock,
lock_v6: RwSpinLock,
}
impl ConnectionCache {
pub fn new() -> Self {
Self {
connections_v4: ConnectionMap::new(),
connections_v6: ConnectionMap::new(),
lock_v4: RwSpinLock::default(),
lock_v6: RwSpinLock::default(),
}
}
pub fn add_connection_v4(&mut self, connection: ConnectionV4) {
let _guard = self.lock_v4.write_lock();
self.connections_v4.add(connection);
}
pub fn add_connection_v6(&mut self, connection: ConnectionV6) {
let _guard = self.lock_v6.write_lock();
self.connections_v6.add(connection);
}
pub fn update_connection(&mut self, key: Key, verdict: Verdict) -> Option<RedirectInfo> {
if key.is_ipv6() {
let _guard = self.lock_v6.write_lock();
if let Some(conn) = self.connections_v6.get_mut(&key) {
conn.verdict = verdict;
return conn.redirect_info();
}
} else {
let _guard = self.lock_v4.write_lock();
if let Some(conn) = self.connections_v4.get_mut(&key) {
conn.verdict = verdict;
return conn.redirect_info();
}
}
None
}
pub fn read_connection_v4<T>(
&self,
key: &Key,
process_connection: fn(&ConnectionV4) -> Option<T>,
) -> Option<T> {
let _guard = self.lock_v4.read_lock();
self.connections_v4.read(&key, process_connection)
}
pub fn read_connection_v6<T>(
&self,
key: &Key,
process_connection: fn(&ConnectionV6) -> Option<T>,
) -> Option<T> {
let _guard = self.lock_v6.read_lock();
self.connections_v6.read(&key, process_connection)
}
pub fn end_connection_v4(&mut self, key: Key) -> Option<ConnectionV4> {
let _guard = self.lock_v4.write_lock();
self.connections_v4.end(key)
}
pub fn end_connection_v6(&mut self, key: Key) -> Option<ConnectionV6> {
let _guard = self.lock_v6.write_lock();
self.connections_v6.end(key)
}
pub fn end_all_on_port_v4(&mut self, key: (IpProtocol, u16)) -> Option<Vec<ConnectionV4>> {
let _guard = self.lock_v4.write_lock();
self.connections_v4.end_all_on_port(key)
}
pub fn end_all_on_port_v6(&mut self, key: (IpProtocol, u16)) -> Option<Vec<ConnectionV6>> {
let _guard = self.lock_v6.write_lock();
self.connections_v6.end_all_on_port(key)
}
pub fn clean_ended_connections(&mut self) {
{
let _guard = self.lock_v4.write_lock();
self.connections_v4.clean_ended_connections();
}
{
let _guard = self.lock_v6.write_lock();
self.connections_v6.clean_ended_connections();
}
}
pub fn clear(&mut self) {
{
let _guard = self.lock_v4.write_lock();
self.connections_v4.clear();
}
{
let _guard = self.lock_v6.write_lock();
self.connections_v6.clear();
}
}
#[allow(dead_code)]
pub fn get_entries_count(&self) -> usize {
let mut size = 0;
{
let _guard = self.lock_v4.read_lock();
size += self.connections_v4.get_count();
}
{
let _guard = self.lock_v6.read_lock();
size += self.connections_v6.get_count();
}
return size;
}
#[allow(dead_code)]
pub fn get_full_cache_info(&self) -> String {
let mut info = String::new();
let now = wdk::utils::get_system_timestamp_ms();
{
let _guard = self.lock_v4.read_lock();
for ((protocol, port), connections) in self.connections_v4.iter() {
info.push_str(&format!("{} -> {}\n", protocol, port,));
for conn in connections {
let active_time_seconds =
Duration::from_millis(now - conn.get_last_accessed_time()).as_secs();
info.push_str(&format!(
"\t{}:{} -> {}:{} {} last active {}m {}s ago",
conn.local_address,
conn.local_port,
conn.remote_address,
conn.remote_port,
conn.verdict,
active_time_seconds / 60,
active_time_seconds % 60
));
if conn.has_ended() {
let end_time_seconds =
Duration::from_millis(now - conn.get_end_time()).as_secs();
info.push_str(&format!(
"\t ended {}m {}s ago",
end_time_seconds / 60,
end_time_seconds % 60
));
}
info.push('\n');
}
}
}
{
let _guard = self.lock_v6.read_lock();
for ((protocol, port), connections) in self.connections_v6.iter() {
info.push_str(&format!("{} -> {} \n", protocol, port));
for conn in connections {
let active_time_seconds =
Duration::from_millis(now - conn.get_last_accessed_time()).as_secs();
info.push_str(&format!(
"\t{}:{} -> {}:{} {} last active {}m {}s ago",
conn.local_address,
conn.local_port,
conn.remote_address,
conn.remote_port,
conn.verdict,
active_time_seconds / 60,
active_time_seconds % 60
));
if conn.has_ended() {
let end_time_seconds =
Duration::from_millis(now - conn.get_end_time()).as_secs();
info.push_str(&format!(
"\t ended {}m {}s ago",
end_time_seconds / 60,
end_time_seconds % 60
));
}
info.push('\n');
}
}
}
return info;
}
}

View File

@@ -0,0 +1,179 @@
use core::{fmt::Display, time::Duration};
use crate::connection::Connection;
use alloc::vec::Vec;
use hashbrown::HashMap;
use smoltcp::wire::{IpAddress, IpProtocol};
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
pub struct Key {
pub(crate) protocol: IpProtocol,
pub(crate) local_address: IpAddress,
pub(crate) local_port: u16,
pub(crate) remote_address: IpAddress,
pub(crate) remote_port: u16,
}
impl Display for Key {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
f,
"p: {} l: {}:{} r: {}:{}",
self.protocol,
self.local_address,
self.local_port,
self.remote_address,
self.remote_port
)
}
}
impl Key {
/// Returns the protocol and port as a tuple.
pub fn small(&self) -> (IpProtocol, u16) {
(self.protocol, self.local_port)
}
/// Returns true if the local address is an IPv4 address.
pub fn is_ipv6(&self) -> bool {
match self.local_address {
IpAddress::Ipv4(_) => false,
IpAddress::Ipv6(_) => true,
}
}
/// Returns true if the local address is a loopback address.
pub fn is_loopback(&self) -> bool {
match self.local_address {
IpAddress::Ipv4(ip) => ip.is_loopback(),
IpAddress::Ipv6(ip) => ip.is_loopback(),
}
}
/// Returns a new key with the local and remote addresses and ports reversed.
#[allow(dead_code)]
pub fn reverse(&self) -> Key {
Key {
protocol: self.protocol,
local_address: self.remote_address,
local_port: self.remote_port,
remote_address: self.local_address,
remote_port: self.local_port,
}
}
}
pub struct ConnectionMap<T: Connection>(HashMap<(IpProtocol, u16), Vec<T>>);
impl<T: Connection + Clone> ConnectionMap<T> {
pub fn new() -> Self {
Self(HashMap::new())
}
pub fn add(&mut self, conn: T) {
let key = conn.get_key().small();
if let Some(connections) = self.0.get_mut(&key) {
connections.push(conn);
} else {
self.0.insert(key, alloc::vec![conn]);
}
}
pub fn get_mut(&mut self, key: &Key) -> Option<&mut T> {
if let Some(connections) = self.0.get_mut(&key.small()) {
for conn in connections {
if conn.remote_equals(key) {
conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
return Some(conn);
}
}
}
None
}
pub fn read<C>(&self, key: &Key, read_connection: fn(&T) -> Option<C>) -> Option<C> {
if let Some(connections) = self.0.get(&key.small()) {
for conn in connections {
if conn.remote_equals(key) {
conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
return read_connection(conn);
}
if conn.redirect_equals(key) {
conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
return read_connection(conn);
}
}
}
None
}
pub fn end(&mut self, key: Key) -> Option<T> {
if let Some(connections) = self.0.get_mut(&key.small()) {
for (_, conn) in connections.iter_mut().enumerate() {
if conn.remote_equals(&key) {
conn.end(wdk::utils::get_system_timestamp_ms());
return Some(conn.clone());
}
}
}
return None;
}
pub fn end_all_on_port(&mut self, key: (IpProtocol, u16)) -> Option<Vec<T>> {
if let Some(connections) = self.0.get_mut(&key) {
let mut vec = Vec::with_capacity(connections.len());
for (_, conn) in connections.iter_mut().enumerate() {
if !conn.has_ended() {
conn.end(wdk::utils::get_system_timestamp_ms());
vec.push(conn.clone());
}
}
return Some(vec);
}
return None;
}
pub fn clear(&mut self) {
self.0.clear();
}
pub fn clean_ended_connections(&mut self) {
let now = wdk::utils::get_system_timestamp_ms();
const TEN_MINUETS: u64 = Duration::from_secs(60 * 10).as_millis() as u64;
let before_ten_minutes = now - TEN_MINUETS;
let before_one_minute = now - Duration::from_secs(60).as_millis() as u64;
for (_, connections) in self.0.iter_mut() {
connections.retain(|c| {
if c.has_ended() && c.get_end_time() < before_one_minute {
// Ended more than 1 minute ago
return false;
}
if c.get_last_accessed_time() < before_ten_minutes {
// Last active more than 10 minutes ago
return false;
}
// Keep
return true;
});
}
self.0.retain(|_, v| !v.is_empty());
}
#[allow(dead_code)]
pub fn get_count(&self) -> usize {
let mut count = 0;
for conn in self.0.values() {
count += conn.len();
}
return count;
}
pub fn iter(&self) -> hashbrown::hash_map::Iter<'_, (IpProtocol, u16), Vec<T>> {
self.0.iter()
}
}

View File

@@ -0,0 +1,329 @@
use alloc::string::String;
use num_traits::FromPrimitive;
use protocol::{command::CommandType, info::Info};
use smoltcp::wire::{IpAddress, IpProtocol, Ipv4Address, Ipv6Address};
use wdk::{
driver::Driver,
filter_engine::{
callout_data::ClassifyDefer,
net_buffer::{NetBufferList, NetworkAllocator},
packet::{InjectInfo, Injector},
FilterEngine,
},
ioqueue::{self, IOQueue},
irp_helpers::{ReadRequest, WriteRequest},
};
use crate::{
array_holder::ArrayHolder, bandwidth::Bandwidth, callouts, connection_cache::ConnectionCache,
connection_map::Key, dbg, err, id_cache::IdCache, logger, packet_util::Redirect,
};
pub enum Packet {
PacketLayer(NetBufferList, InjectInfo),
AleLayer(ClassifyDefer),
}
// Device Context
pub struct Device {
pub(crate) filter_engine: FilterEngine,
pub(crate) read_leftover: ArrayHolder,
pub(crate) event_queue: IOQueue<Info>,
pub(crate) packet_cache: IdCache,
pub(crate) connection_cache: ConnectionCache,
pub(crate) injector: Injector,
pub(crate) network_allocator: NetworkAllocator,
pub(crate) bandwidth_stats: Bandwidth,
}
impl Device {
/// Initialize all members of the device. Memory is handled by windows.
/// Make sure everything is initialized here.
pub fn new(driver: &Driver) -> Result<Self, String> {
let mut filter_engine =
match FilterEngine::new(driver, 0x7dab1057_8e2b_40c4_9b85_693e381d7896) {
Ok(fe) => fe,
Err(err) => return Err(alloc::format!("filter engine error: {}", err)),
};
if let Err(err) = filter_engine.commit(callouts::get_callout_vec()) {
return Err(err);
}
Ok(Self {
filter_engine,
read_leftover: ArrayHolder::default(),
event_queue: IOQueue::new(),
packet_cache: IdCache::new(),
connection_cache: ConnectionCache::new(),
injector: Injector::new(),
network_allocator: NetworkAllocator::new(),
bandwidth_stats: Bandwidth::new(),
})
}
/// Cleanup is called just before drop.
// pub fn cleanup(&mut self) {}
fn write_buffer(&mut self, read_request: &mut ReadRequest, info: Info) {
let bytes = info.as_bytes();
let count = read_request.write(bytes);
// Check if the full buffer was written.
if count < bytes.len() {
// Save the leftovers for later.
self.read_leftover.save(&bytes[count..]);
}
}
/// Called when handle. Read is called from user-space.
pub fn read(&mut self, read_request: &mut ReadRequest) {
if let Some(data) = self.read_leftover.load() {
// There are leftovers from previous request.
let count = read_request.write(&data);
// Check if full command was written.
if count < data.len() {
// Save the leftovers for later.
self.read_leftover.save(&data[count..]);
}
} else {
// Noting left from before. Wait for next commands.
match self.event_queue.wait_and_pop() {
Ok(info) => {
self.write_buffer(read_request, info);
}
Err(ioqueue::Status::Timeout) => {
// Timeout. This will only trigger if pop function is called with timeout.
read_request.timeout();
return;
}
Err(err) => {
// Queue failed. Send EOF, to notify user-space. Usually happens on rundown.
err!("failed to pop value: {}", err);
read_request.end_of_file();
return;
}
}
}
// Check if we have more space. InfoType + data_size == 5 bytes
while read_request.free_space() > 5 {
match self.event_queue.pop() {
Ok(info) => {
self.write_buffer(read_request, info);
}
Err(_) => {
break;
}
}
}
read_request.complete();
}
// Called when handle.Write is called from user-space.
pub fn write(&mut self, write_request: &mut WriteRequest) {
// Try parsing the command.
let mut buffer = write_request.get_buffer();
let command = protocol::command::parse_type(buffer);
let Some(command) = command else {
err!("Unknown command number: {}", buffer[0]);
return;
};
buffer = &buffer[1..];
let mut _classify_defer = None;
match command {
CommandType::Shutdown => {
wdk::dbg!("Shutdown command");
self.shutdown();
}
CommandType::Verdict => {
let verdict = protocol::command::parse_verdict(buffer);
wdk::dbg!("Verdict command");
// Received verdict decision for a specific connection.
if let Some((key, mut packet)) = self.packet_cache.pop_id(verdict.id) {
if let Some(verdict) = FromPrimitive::from_u8(verdict.verdict) {
dbg!("Verdict received {}: {}", key, verdict);
// Add verdict in the cache.
let redirect_info = self.connection_cache.update_connection(key, verdict);
// if verdict.is_permanent() {
// dbg!(self.logger, "resetting filters {}: {}", key, verdict);
// _ = self.filter_engine.reset_all_filters();
// }
match verdict {
crate::connection::Verdict::Accept
| crate::connection::Verdict::PermanentAccept => {
if let Err(err) = self.inject_packet(packet, false) {
err!("failed to inject packet: {}", err);
} else {
dbg!("packet injected: {}", key);
}
}
crate::connection::Verdict::RedirectNameServer
| crate::connection::Verdict::RedirectTunnel => {
if let Some(redirect_info) = redirect_info {
if let Err(err) = packet.redirect(redirect_info) {
err!("failed to redirect packet: {}", err);
}
if let Err(err) = self.inject_packet(packet, false) {
err!("failed to inject packet: {}", err);
}
}
}
_ => {
if let Err(err) = self.inject_packet(packet, true) {
err!("failed to inject packet: {}", err);
}
}
}
};
} else {
// Id was not in the packet cache.
let id = verdict.id;
err!("Verdict invalid id: {}", id);
}
}
CommandType::UpdateV4 => {
let update = protocol::command::parse_update_v4(buffer);
// Build the new action.
if let Some(verdict) = FromPrimitive::from_u8(update.verdict) {
// Update with new action.
dbg!("Verdict update received {:?}: {}", update, verdict);
_classify_defer = self.connection_cache.update_connection(
Key {
protocol: IpProtocol::from(update.protocol),
local_address: IpAddress::Ipv4(Ipv4Address::from_bytes(
&update.local_address,
)),
local_port: update.local_port,
remote_address: IpAddress::Ipv4(Ipv4Address::from_bytes(
&update.remote_address,
)),
remote_port: update.remote_port,
},
verdict,
);
} else {
err!("invalid verdict value: {}", update.verdict);
}
}
CommandType::UpdateV6 => {
let update = protocol::command::parse_update_v6(buffer);
// Build the new action.
if let Some(verdict) = FromPrimitive::from_u8(update.verdict) {
// Update with new action.
dbg!("Verdict update received {:?}: {}", update, verdict);
_classify_defer = self.connection_cache.update_connection(
Key {
protocol: IpProtocol::from(update.protocol),
local_address: IpAddress::Ipv6(Ipv6Address::from_bytes(
&update.local_address,
)),
local_port: update.local_port,
remote_address: IpAddress::Ipv6(Ipv6Address::from_bytes(
&update.remote_address,
)),
remote_port: update.remote_port,
},
verdict,
);
} else {
err!("invalid verdict value: {}", update.verdict);
}
}
CommandType::ClearCache => {
wdk::dbg!("ClearCache command");
self.connection_cache.clear();
if let Err(err) = self.filter_engine.reset_all_filters() {
err!("failed to reset filters: {}", err);
}
}
CommandType::GetLogs => {
wdk::dbg!("GetLogs command");
let lines_vec = logger::flush();
for line in lines_vec {
let _ = self.event_queue.push(line);
}
}
CommandType::GetBandwidthStats => {
wdk::dbg!("GetBandwidthStats command");
let stats = self.bandwidth_stats.get_all_updates_tcp_v4();
if let Some(stats) = stats {
_ = self.event_queue.push(stats);
}
let stats = self.bandwidth_stats.get_all_updates_tcp_v6();
if let Some(stats) = stats {
_ = self.event_queue.push(stats);
}
let stats = self.bandwidth_stats.get_all_updates_udp_v4();
if let Some(stats) = stats {
_ = self.event_queue.push(stats);
}
let stats = self.bandwidth_stats.get_all_updates_udp_v6();
if let Some(stats) = stats {
_ = self.event_queue.push(stats);
}
}
CommandType::PrintMemoryStats => {
// Getting the information takes a long time and interferes with the callouts causing the device to crash.
// TODO(vladimir): Make more optimized version
// info!(
// "Packet cache: {} entries",
// self.packet_cache.get_entries_count()
// );
// info!(
// "BandwidthStats cache: {} entries",
// self.bandwidth_stats.get_entries_count()
// );
// info!(
// "Connection cache: {} entries\n {}",
// self.connection_cache.get_entries_count(),
// self.connection_cache.get_full_cache_info()
// );
}
CommandType::CleanEndedConnections => {
wdk::dbg!("CleanEndedConnections command");
self.connection_cache.clean_ended_connections();
}
}
}
pub fn shutdown(&self) {
// End blocking operations from the queue. This will end pending read requests.
self.event_queue.rundown();
}
pub fn inject_packet(&mut self, packet: Packet, blocked: bool) -> Result<(), String> {
match packet {
Packet::PacketLayer(nbl, inject_info) => {
if !blocked {
self.injector.inject_net_buffer_list(nbl, inject_info)
} else {
Ok(())
}
}
Packet::AleLayer(defer) => {
let packet_list = defer.complete(&mut self.filter_engine)?;
if let Some(packet_list) = packet_list {
self.injector.inject_packet_list_transport(packet_list)?;
}
Ok(())
}
}
}
}
impl Drop for Device {
fn drop(&mut self) {
_ = logger::flush();
// dbg!("Device Context drop called.");
}
}

View File

@@ -0,0 +1,25 @@
use core::ops::{Deref, DerefMut};
use hashbrown::HashMap;
pub struct DeviceHashMap<Key, Value>(Option<HashMap<Key, Value>>);
impl<Key, Value> DeviceHashMap<Key, Value> {
pub fn new() -> Self {
Self(Some(HashMap::new()))
}
}
impl<Key, Value> Deref for DeviceHashMap<Key, Value> {
type Target = HashMap<Key, Value>;
fn deref(&self) -> &Self::Target {
self.0.as_ref().unwrap()
}
}
impl<Key, Value> DerefMut for DeviceHashMap<Key, Value> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.as_mut().unwrap()
}
}

View File

@@ -0,0 +1,135 @@
use crate::common::ControlCode;
use crate::device;
use alloc::boxed::Box;
use num_traits::FromPrimitive;
use wdk::irp_helpers::{DeviceControlRequest, ReadRequest, WriteRequest};
use wdk::{err, info, interface};
use windows_sys::Wdk::Foundation::{DEVICE_OBJECT, DRIVER_OBJECT, IRP};
use windows_sys::Win32::Foundation::{NTSTATUS, STATUS_SUCCESS};
static VERSION: [u8; 4] = include!("../../kext_interface/version.txt");
static mut DEVICE: *mut device::Device = core::ptr::null_mut();
pub fn get_device() -> Option<&'static mut device::Device> {
return unsafe { DEVICE.as_mut() };
}
// DriverEntry is the entry point of the driver (main function). Will be called when driver is loaded.
// Name should not be changed
#[export_name = "DriverEntry"]
pub extern "system" fn driver_entry(
driver_object: *mut windows_sys::Wdk::Foundation::DRIVER_OBJECT,
registry_path: *mut windows_sys::Win32::Foundation::UNICODE_STRING,
) -> windows_sys::Win32::Foundation::NTSTATUS {
info!("Starting initialization...");
// Initialize driver object.
let mut driver = match interface::init_driver_object(
driver_object,
registry_path,
"PortmasterKext",
core::ptr::null_mut(),
) {
Ok(driver) => driver,
Err(status) => {
err!("driver_entry: failed to initialize driver: {}", status);
return windows_sys::Win32::Foundation::STATUS_FAILED_DRIVER_ENTRY;
}
};
// Set driver functions.
driver.set_driver_unload(driver_unload);
driver.set_read_fn(driver_read);
driver.set_write_fn(driver_write);
driver.set_device_control_fn(device_control);
// Initialize device.
unsafe {
let device = match device::Device::new(&driver) {
Ok(device) => Box::new(device),
Err(err) => {
wdk::err!("filed to initialize device: {}", err);
return -1;
}
};
DEVICE = Box::into_raw(device);
}
STATUS_SUCCESS
}
// driver_unload function is called when service delete is called from user-space.
unsafe extern "system" fn driver_unload(_object: *const DRIVER_OBJECT) {
info!("Unloading complete");
unsafe {
if !DEVICE.is_null() {
_ = Box::from_raw(DEVICE);
}
}
}
// driver_read event triggered from user-space on file.Read.
unsafe extern "system" fn driver_read(
_device_object: &mut DEVICE_OBJECT,
irp: &mut IRP,
) -> NTSTATUS {
let mut read_request = ReadRequest::new(irp);
let Some(device) = get_device() else {
read_request.complete();
return read_request.get_status();
};
device.read(&mut read_request);
read_request.get_status()
}
/// driver_write event triggered from user-space on file.Write.
unsafe extern "system" fn driver_write(
_device_object: &mut DEVICE_OBJECT,
irp: &mut IRP,
) -> NTSTATUS {
let mut write_request = WriteRequest::new(irp);
let Some(device) = get_device() else {
write_request.complete();
return write_request.get_status();
};
device.write(&mut write_request);
write_request.mark_all_as_read();
write_request.complete();
write_request.get_status()
}
/// device_control event triggered from user-space on file.deviceIOControl.
unsafe extern "system" fn device_control(
_device_object: &mut DEVICE_OBJECT,
irp: &mut IRP,
) -> NTSTATUS {
let mut control_request = DeviceControlRequest::new(irp);
let Some(device) = get_device() else {
control_request.complete();
return control_request.get_status();
};
let Some(control_code): Option<ControlCode> =
FromPrimitive::from_u32(control_request.get_control_code())
else {
wdk::info!("Unknown IOCTL code: {}", control_request.get_control_code());
control_request.not_implemented();
return control_request.get_status();
};
wdk::info!("IOCTL: {}", control_code);
match control_code {
ControlCode::Version => {
control_request.write(&VERSION);
}
ControlCode::ShutdownRequest => device.shutdown(),
};
control_request.complete();
control_request.get_status()
}

View File

@@ -0,0 +1,131 @@
use alloc::collections::VecDeque;
use protocol::info::Info;
use smoltcp::wire::{IpAddress, IpProtocol};
use wdk::rw_spin_lock::RwSpinLock;
use crate::{connection::Direction, connection_map::Key, device::Packet};
struct Entry<T> {
value: T,
id: u64,
}
pub struct IdCache {
values: VecDeque<Entry<(Key, Packet)>>,
lock: RwSpinLock,
next_id: u64,
}
impl IdCache {
pub fn new() -> Self {
Self {
values: VecDeque::with_capacity(1000),
lock: RwSpinLock::default(),
next_id: 1, // 0 is invalid id
}
}
pub fn push(
&mut self,
value: (Key, Packet),
process_id: u64,
direction: Direction,
ale_layer: bool,
) -> Option<Info> {
let _guard = self.lock.write_lock();
let id = self.next_id;
let info = build_info(&value.0, id, process_id, direction, &value.1, ale_layer);
self.values.push_back(Entry { value, id });
self.next_id = self.next_id.wrapping_add(1); // Assuming this will not overflow.
return info;
}
pub fn pop_id(&mut self, id: u64) -> Option<(Key, Packet)> {
let _guard = self.lock.write_lock();
if let Ok(index) = self.values.binary_search_by_key(&id, |val| val.id) {
return Some(self.values.remove(index).unwrap().value);
}
None
}
#[allow(dead_code)]
pub fn get_entries_count(&self) -> usize {
let _guard = self.lock.read_lock();
return self.values.len();
}
}
fn get_payload<'a>(packet: &'a Packet) -> Option<&'a [u8]> {
match packet {
Packet::PacketLayer(nbl, _) => nbl.get_data(),
Packet::AleLayer(defer) => {
let p = match defer {
wdk::filter_engine::callout_data::ClassifyDefer::Initial(_, p) => p,
wdk::filter_engine::callout_data::ClassifyDefer::Reauthorization(_, p) => p,
};
if let Some(tpl) = p {
tpl.net_buffer_list_queue.get_data()
} else {
None
}
}
}
}
fn build_info(
key: &Key,
packet_id: u64,
process_id: u64,
direction: Direction,
packet: &Packet,
ale_layer: bool,
) -> Option<Info> {
let (local_port, remote_port) = match key.protocol {
IpProtocol::Tcp | IpProtocol::Udp => (key.local_port, key.remote_port),
_ => (0, 0),
};
let payload_layer = if ale_layer {
4 // Transport layer
} else {
3 // Network layer
};
let mut payload = &[][..];
if let Some(p) = get_payload(packet) {
payload = p;
}
match (key.local_address, key.remote_address) {
(IpAddress::Ipv6(local_ip), IpAddress::Ipv6(remote_ip)) if key.is_ipv6() => {
Some(protocol::info::connection_info_v6(
packet_id,
process_id,
direction as u8,
u8::from(key.protocol),
local_ip.0,
remote_ip.0,
local_port,
remote_port,
payload_layer,
payload,
))
}
(IpAddress::Ipv4(local_ip), IpAddress::Ipv4(remote_ip)) => {
Some(protocol::info::connection_info_v4(
packet_id,
process_id,
direction as u8,
u8::from(key.protocol),
local_ip.0,
remote_ip.0,
local_port,
remote_port,
payload_layer,
payload,
))
}
_ => None,
}
}

View File

@@ -0,0 +1,43 @@
#![cfg_attr(not(test), no_std)]
#![no_main]
#![allow(clippy::needless_return)]
extern crate alloc;
mod ale_callouts;
mod array_holder;
mod bandwidth;
mod callouts;
mod common;
mod connection;
mod connection_cache;
mod connection_map;
mod device;
mod driver_hashmap;
mod entry;
mod id_cache;
pub mod logger;
mod packet_callouts;
mod packet_util;
mod stream_callouts;
use wdk::allocator::WindowsAllocator;
#[cfg(not(test))]
use core::panic::PanicInfo;
// Declaration of the global memory allocator
#[global_allocator]
static HEAP: WindowsAllocator = WindowsAllocator {};
#[no_mangle]
pub extern "system" fn _DllMainCRTStartup() {}
#[cfg(not(test))]
#[panic_handler]
fn panic(info: &PanicInfo) -> ! {
use wdk::err;
err!("{}", info);
loop {}
}

View File

@@ -0,0 +1,114 @@
use alloc::boxed::Box;
use alloc::vec::Vec;
use core::{
mem::MaybeUninit,
sync::atomic::{AtomicPtr, AtomicUsize, Ordering},
};
use protocol::info::{Info, Severity};
#[cfg(not(debug_assertions))]
pub const LOG_LEVEL: u8 = Severity::Warning as u8;
#[cfg(debug_assertions)]
pub const LOG_LEVEL: u8 = Severity::Trace as u8;
pub const MAX_LOG_LINE_SIZE: usize = 150;
static mut LOG_LINES: [AtomicPtr<Info>; 1024] = unsafe { MaybeUninit::zeroed().assume_init() };
static START_INDEX: AtomicUsize = unsafe { MaybeUninit::zeroed().assume_init() };
static END_INDEX: AtomicUsize = unsafe { MaybeUninit::zeroed().assume_init() };
pub fn add_line(log_line: Info) {
let mut index = END_INDEX.fetch_add(1, Ordering::Acquire);
unsafe {
index %= LOG_LINES.len();
let ptr = &mut LOG_LINES[index];
let line = Box::new(log_line);
let old = ptr.swap(Box::into_raw(line), Ordering::SeqCst);
if !old.is_null() {
_ = Box::from_raw(old);
}
}
}
pub fn flush() -> Vec<Info> {
let mut vec = Vec::new();
let end_index = END_INDEX.load(Ordering::Acquire);
let start_index = START_INDEX.load(Ordering::Acquire);
if end_index <= start_index {
return vec;
}
unsafe {
let count = end_index - start_index;
for i in start_index..start_index + count {
let index = i % LOG_LINES.len();
let ptr = LOG_LINES[index].swap(core::ptr::null_mut(), Ordering::SeqCst);
if !ptr.is_null() {
vec.push(*Box::from_raw(ptr));
}
}
}
START_INDEX.store(end_index, Ordering::Release);
vec
}
#[macro_export]
macro_rules! log_internal {
($log_line:expr, $($arg:tt)*) => ({
use core::fmt::Write;
_ = write!($log_line, "{}:{} ", file!(), line!());
_ = write!($log_line, $($arg)*);
$crate::logger::add_line($log_line);
});
}
#[macro_export]
macro_rules! crit {
($($arg:tt)*) => ({
if protocol::info::Severity::Critical as u8 >= $crate::logger::LOG_LEVEL {
let message = alloc::format!($($arg)*);
$crate::logger::add_line(protocol::info::Severity::Critical, alloc::format!("{}:{} ", file!(), line!()), message)
}
});
}
#[macro_export]
macro_rules! err {
($($arg:tt)*) => ({
if protocol::info::Severity::Error as u8 >= $crate::logger::LOG_LEVEL {
let mut log_line = protocol::info::log_line(protocol::info::Severity::Error, $crate::logger::MAX_LOG_LINE_SIZE);
$crate::log_internal!(log_line, $($arg)*);
}
});
}
#[macro_export]
macro_rules! warn {
($($arg:tt)*) => ({
if protocol::info::Severity::Warning as u8 >= $crate::logger::LOG_LEVEL {
let mut log_line = protocol::info::log_line(protocol::info::Severity::Warning, $crate::logger::MAX_LOG_LINE_SIZE);
$crate::log_internal!(log_line, $($arg)*);
}
});
}
#[macro_export]
macro_rules! dbg {
($($arg:tt)*) => ({
if protocol::info::Severity::Debug as u8 >= $crate::logger::LOG_LEVEL {
let mut log_line = protocol::info::log_line(protocol::info::Severity::Debug, $crate::logger::MAX_LOG_LINE_SIZE);
$crate::log_internal!(log_line, $($arg)*);
}
});
}
#[macro_export]
macro_rules! info {
($($arg:tt)*) => ({
if protocol::info::Severity::Info as u8 >= $crate::logger::LOG_LEVEL {
let mut log_line = protocol::info::log_line(protocol::info::Severity::Info, $crate::logger::MAX_LOG_LINE_SIZE);
$crate::log_internal!(log_line, $($arg)*);
}
});
}

View File

@@ -0,0 +1,298 @@
use alloc::string::String;
use smoltcp::wire::{IPV4_HEADER_LEN, IPV6_HEADER_LEN};
use wdk::filter_engine::callout_data::CalloutData;
use wdk::filter_engine::layer;
use wdk::filter_engine::net_buffer::{NetBufferList, NetBufferListIter};
use wdk::filter_engine::packet::InjectInfo;
use crate::connection::{
Connection, ConnectionV4, ConnectionV6, Direction, RedirectInfo, Verdict, PM_DNS_PORT,
PM_SPN_PORT,
};
use crate::connection_cache::ConnectionCache;
use crate::connection_map::Key;
use crate::device::{Device, Packet};
use crate::packet_util::{get_key_from_nbl_v4, get_key_from_nbl_v6, Redirect};
use crate::{err, warn};
// IP packet layers
pub fn ip_packet_layer_outbound_v4(data: CalloutData) {
type Fields = layer::FieldsOutboundIppacketV4;
let interface_index = data.get_value_u32(Fields::InterfaceIndex as usize);
let sub_interface_index = data.get_value_u32(Fields::SubInterfaceIndex as usize);
ip_packet_layer(
data,
false,
Direction::Outbound,
interface_index,
sub_interface_index,
);
}
pub fn ip_packet_layer_inbound_v4(data: CalloutData) {
type Fields = layer::FieldsInboundIppacketV4;
let interface_index = data.get_value_u32(Fields::InterfaceIndex as usize);
let sub_interface_index = data.get_value_u32(Fields::SubInterfaceIndex as usize);
ip_packet_layer(
data,
false,
Direction::Inbound,
interface_index,
sub_interface_index,
);
}
pub fn ip_packet_layer_outbound_v6(data: CalloutData) {
type Fields = layer::FieldsOutboundIppacketV6;
let interface_index = data.get_value_u32(Fields::InterfaceIndex as usize);
let sub_interface_index = data.get_value_u32(Fields::SubInterfaceIndex as usize);
ip_packet_layer(
data,
true,
Direction::Outbound,
interface_index,
sub_interface_index,
);
}
pub fn ip_packet_layer_inbound_v6(data: CalloutData) {
type Fields = layer::FieldsInboundIppacketV6;
let interface_index = data.get_value_u32(Fields::InterfaceIndex as usize);
let sub_interface_index = data.get_value_u32(Fields::SubInterfaceIndex as usize);
ip_packet_layer(
data,
true,
Direction::Inbound,
interface_index,
sub_interface_index,
);
}
struct ConnectionInfo {
verdict: Verdict,
process_id: u64,
redirect_info: Option<RedirectInfo>,
}
impl ConnectionInfo {
fn from_connection<T: Connection>(conn: &T) -> Self {
ConnectionInfo {
verdict: conn.get_verdict(),
process_id: conn.get_process_id(),
redirect_info: conn.redirect_info(),
}
}
}
fn fast_track_pm_packets(key: &Key, direction: Direction) -> bool {
match direction {
Direction::Outbound => {
if key.local_port == PM_DNS_PORT || key.local_port == PM_SPN_PORT {
return key.local_address == key.remote_address;
}
}
Direction::Inbound => {
if key.local_port == PM_DNS_PORT || key.local_port == PM_SPN_PORT {
return key.local_address == key.remote_address;
}
}
}
return false;
}
fn ip_packet_layer(
mut data: CalloutData,
ipv6: bool,
direction: Direction,
interface_index: u32,
sub_interface_index: u32,
) {
let Some(device) = crate::entry::get_device() else {
return;
};
if device
.injector
.was_network_packet_injected_by_self(data.get_layer_data() as _, ipv6)
{
data.action_permit();
return;
}
for mut nbl in NetBufferListIter::new(data.get_layer_data() as _) {
if let Direction::Inbound = direction {
// The header is not part of the NBL for incoming packets. Move the beginning of the buffer back so we get access to it.
// The NBL will auto advance after it loses scope.
if ipv6 {
nbl.retreat(IPV6_HEADER_LEN as u32, true);
} else {
nbl.retreat(IPV4_HEADER_LEN as u32, true);
}
}
// Get key from packet.
let key = match if ipv6 {
get_key_from_nbl_v6(&nbl, direction)
} else {
get_key_from_nbl_v4(&nbl, direction)
} {
Ok(key) => key,
Err(err) => {
warn!("failed to get key from nbl: {}", err);
return;
}
};
if fast_track_pm_packets(&key, direction) {
data.action_permit();
return;
}
let mut is_tmp_verdict = false;
let mut process_id = 0;
if matches!(
key.protocol,
smoltcp::wire::IpProtocol::Tcp | smoltcp::wire::IpProtocol::Udp
) {
if let Some(mut conn_info) =
get_connection_info(&mut device.connection_cache, &key, ipv6)
{
process_id = conn_info.process_id;
// Check if there is action for this connection.
match conn_info.verdict {
Verdict::Undecided | Verdict::Accept | Verdict::Block | Verdict::Drop => {
is_tmp_verdict = true
}
Verdict::PermanentAccept => data.action_permit(),
Verdict::PermanentBlock => data.action_block(),
Verdict::Undeterminable | Verdict::PermanentDrop | Verdict::Failed => {
data.block_and_absorb()
}
Verdict::RedirectNameServer | Verdict::RedirectTunnel => {
if let Some(redirect_info) = conn_info.redirect_info.take() {
match clone_packet(
device,
nbl,
direction,
ipv6,
key.is_loopback(),
interface_index,
sub_interface_index,
) {
Ok(mut packet) => {
let _ = packet.redirect(redirect_info);
if let Err(err) = device.inject_packet(packet, false) {
err!("failed to inject packet: {}", err);
}
}
Err(err) => err!("failed to clone packet: {}", err),
}
}
// This will block the original packet. Even if injection failed.
data.block_and_absorb();
continue;
}
}
} else {
// TCP and UDP always need to go through ALE layer first.
if matches!(direction, Direction::Inbound) {
// If it's an inbound packet and the connection is not found, we need to continue to ALE layer
data.action_permit();
return;
} else {
// This happens sometimes. Leave the decision for portmaster. TODO(vladimir): Find out why.
err!("Invalid state for: {}", key);
is_tmp_verdict = true;
}
}
} else {
// Every other protocol treat as a tmp verdict.
is_tmp_verdict = true;
}
// Clone packet and send to Portmaster if it's a temporary verdict.
if is_tmp_verdict {
let packet = match clone_packet(
device,
nbl,
direction,
ipv6,
key.is_loopback(),
interface_index,
sub_interface_index,
) {
Ok(p) => p,
Err(err) => {
err!("failed to clone packet: {}", err);
return;
}
};
let info = device
.packet_cache
.push((key, packet), process_id, direction, false);
// Send to Portmaster
if let Some(info) = info {
let _ = device.event_queue.push(info);
}
data.block_and_absorb();
}
}
}
fn clone_packet(
device: &mut Device,
nbl: NetBufferList,
direction: Direction,
ipv6: bool,
loopback: bool,
interface_index: u32,
sub_interface_index: u32,
) -> Result<Packet, String> {
let clone = nbl.clone(&device.network_allocator)?;
let inbound = match direction {
Direction::Outbound => false,
Direction::Inbound => true,
};
Ok(Packet::PacketLayer(
clone,
InjectInfo {
ipv6,
inbound,
loopback,
interface_index,
sub_interface_index,
},
))
}
fn get_connection_info(
connection_cache: &mut ConnectionCache,
key: &Key,
ipv6: bool,
) -> Option<ConnectionInfo> {
if ipv6 {
let conn_info = connection_cache.read_connection_v6(
&key,
|conn: &ConnectionV6| -> Option<ConnectionInfo> {
// Function is is behind spin lock. Just copy and return.
Some(ConnectionInfo::from_connection(conn))
},
);
return conn_info;
} else {
let conn_info = connection_cache.read_connection_v4(
&key,
|conn: &ConnectionV4| -> Option<ConnectionInfo> {
// Function is is behind spin lock. Just copy and return.
Some(ConnectionInfo::from_connection(conn))
},
);
return conn_info;
}
}

View File

@@ -0,0 +1,399 @@
use alloc::string::{String, ToString};
use smoltcp::wire::{
IpAddress, IpProtocol, Ipv4Address, Ipv4Packet, Ipv6Address, Ipv6Packet, TcpPacket, UdpPacket,
};
use wdk::filter_engine::net_buffer::NetBufferList;
use crate::connection_map::Key;
use crate::device::Packet;
use crate::{
connection::{Direction, RedirectInfo},
dbg, err,
};
/// `Redirect` is a trait that defines a method for redirecting network packets.
///
/// This trait is used to implement different strategies for redirecting packets,
/// depending on the specific requirements of the application.
pub trait Redirect {
/// Redirects a network packet based on the provided `RedirectInfo`.
///
/// # Arguments
///
/// * `redirect_info` - A struct containing information about how to redirect the packet.
///
/// # Returns
///
/// * `Ok(())` if the packet was successfully redirected.
/// * `Err(String)` if there was an error redirecting the packet.
fn redirect(&mut self, redirect_info: RedirectInfo) -> Result<(), String>;
}
impl Redirect for Packet {
fn redirect(&mut self, redirect_info: RedirectInfo) -> Result<(), String> {
if let Packet::PacketLayer(nbl, inject_info) = self {
let Some(data) = nbl.get_data_mut() else {
return Err("trying to redirect immutable NBL".to_string());
};
if inject_info.inbound {
redirect_inbound_packet(
data,
redirect_info.local_address,
redirect_info.remote_address,
redirect_info.remote_port,
)
} else {
redirect_outbound_packet(
data,
redirect_info.redirect_address,
redirect_info.redirect_port,
redirect_info.unify,
)
}
return Ok(());
}
// return Err("can't redirect from non packet layer".to_string());
return Ok(());
}
}
/// Redirects an outbound packet to a specified remote address and port.
///
/// # Arguments
///
/// * `packet` - A mutable reference to the packet data.
/// * `remote_address` - The IP address to redirect the packet to.
/// * `remote_port` - The port to redirect the packet to.
/// * `unify` - If true, the source and destination addresses of the packet will be set to the same value.
///
/// This function modifies the packet in-place to change its destination address and port.
/// It also updates the checksums for the IP and transport layer headers.
/// If the `unify` parameter is true, it sets the source and destination addresses to be the same.
/// If the remote address is a loopback address, it sets the source address to the loopback address.
fn redirect_outbound_packet(
packet: &mut [u8],
remote_address: IpAddress,
remote_port: u16,
unify: bool,
) {
match remote_address {
IpAddress::Ipv4(remote_address) => {
if let Ok(mut ip_packet) = Ipv4Packet::new_checked(packet) {
if unify {
ip_packet.set_dst_addr(ip_packet.src_addr());
} else {
ip_packet.set_dst_addr(remote_address);
if remote_address.is_loopback() {
ip_packet.set_src_addr(Ipv4Address::new(127, 0, 0, 1));
}
}
ip_packet.fill_checksum();
let src_addr = ip_packet.src_addr();
let dst_addr = ip_packet.dst_addr();
if ip_packet.next_header() == IpProtocol::Udp {
if let Ok(mut udp_packet) = UdpPacket::new_checked(ip_packet.payload_mut()) {
udp_packet.set_dst_port(remote_port);
udp_packet
.fill_checksum(&IpAddress::Ipv4(src_addr), &IpAddress::Ipv4(dst_addr));
}
}
if ip_packet.next_header() == IpProtocol::Tcp {
if let Ok(mut tcp_packet) = TcpPacket::new_checked(ip_packet.payload_mut()) {
tcp_packet.set_dst_port(remote_port);
tcp_packet
.fill_checksum(&IpAddress::Ipv4(src_addr), &IpAddress::Ipv4(dst_addr));
}
}
}
}
IpAddress::Ipv6(remote_address) => {
if let Ok(mut ip_packet) = Ipv6Packet::new_checked(packet) {
ip_packet.set_dst_addr(remote_address);
if unify {
ip_packet.set_dst_addr(ip_packet.src_addr());
} else {
ip_packet.set_dst_addr(remote_address);
if remote_address.is_loopback() {
ip_packet.set_src_addr(Ipv6Address::LOOPBACK);
}
}
let src_addr = ip_packet.src_addr();
let dst_addr = ip_packet.dst_addr();
if ip_packet.next_header() == IpProtocol::Udp {
if let Ok(mut udp_packet) = UdpPacket::new_checked(ip_packet.payload_mut()) {
udp_packet.set_dst_port(remote_port);
udp_packet
.fill_checksum(&IpAddress::Ipv6(src_addr), &IpAddress::Ipv6(dst_addr));
}
}
if ip_packet.next_header() == IpProtocol::Tcp {
if let Ok(mut tcp_packet) = TcpPacket::new_checked(ip_packet.payload_mut()) {
tcp_packet.set_dst_port(remote_port);
tcp_packet
.fill_checksum(&IpAddress::Ipv6(src_addr), &IpAddress::Ipv6(dst_addr));
}
}
}
}
}
}
/// Redirects an inbound packet to a local address.
///
/// This function takes a mutable reference to a packet and modifies it in place.
/// It changes the destination address to the provided local address and the source address
/// to the original remote address. It also sets the source port to the original remote port.
/// The function handles both IPv4 and IPv6 addresses.
///
/// # Arguments
///
/// * `packet` - A mutable reference to the packet data.
/// * `local_address` - The local IP address to redirect the packet to.
/// * `original_remote_address` - The original remote IP address of the packet.
/// * `original_remote_port` - The original remote port of the packet.
///
fn redirect_inbound_packet(
packet: &mut [u8],
local_address: IpAddress,
original_remote_address: IpAddress,
original_remote_port: u16,
) {
match local_address {
IpAddress::Ipv4(local_address) => {
let IpAddress::Ipv4(original_remote_address) = original_remote_address else {
return;
};
if let Ok(mut ip_packet) = Ipv4Packet::new_checked(packet) {
ip_packet.set_dst_addr(local_address);
ip_packet.set_src_addr(original_remote_address);
ip_packet.fill_checksum();
let src_addr = ip_packet.src_addr();
let dst_addr = ip_packet.dst_addr();
if ip_packet.next_header() == IpProtocol::Udp {
if let Ok(mut udp_packet) = UdpPacket::new_checked(ip_packet.payload_mut()) {
udp_packet.set_src_port(original_remote_port);
udp_packet
.fill_checksum(&IpAddress::Ipv4(src_addr), &IpAddress::Ipv4(dst_addr));
}
}
if ip_packet.next_header() == IpProtocol::Tcp {
if let Ok(mut tcp_packet) = TcpPacket::new_checked(ip_packet.payload_mut()) {
tcp_packet.set_src_port(original_remote_port);
tcp_packet
.fill_checksum(&IpAddress::Ipv4(src_addr), &IpAddress::Ipv4(dst_addr));
}
}
}
}
IpAddress::Ipv6(local_address) => {
if let Ok(mut ip_packet) = Ipv6Packet::new_checked(packet) {
let IpAddress::Ipv6(original_remote_address) = original_remote_address else {
return;
};
ip_packet.set_dst_addr(local_address);
ip_packet.set_src_addr(original_remote_address);
let src_addr = ip_packet.src_addr();
let dst_addr = ip_packet.dst_addr();
if ip_packet.next_header() == IpProtocol::Udp {
if let Ok(mut udp_packet) = UdpPacket::new_checked(ip_packet.payload_mut()) {
udp_packet.set_src_port(original_remote_port);
udp_packet
.fill_checksum(&IpAddress::Ipv6(src_addr), &IpAddress::Ipv6(dst_addr));
}
}
if ip_packet.next_header() == IpProtocol::Tcp {
if let Ok(mut tcp_packet) = TcpPacket::new_checked(ip_packet.payload_mut()) {
tcp_packet.set_src_port(original_remote_port);
tcp_packet
.fill_checksum(&IpAddress::Ipv6(src_addr), &IpAddress::Ipv6(dst_addr));
}
}
}
}
}
}
#[allow(dead_code)]
fn print_packet(packet: &[u8]) {
if let Ok(ip_packet) = Ipv4Packet::new_checked(packet) {
if ip_packet.next_header() == IpProtocol::Udp {
if let Ok(udp_packet) = UdpPacket::new_checked(ip_packet.payload()) {
dbg!("packet {} {}", ip_packet, udp_packet);
}
}
if ip_packet.next_header() == IpProtocol::Tcp {
if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
dbg!("packet {} {}", ip_packet, tcp_packet);
}
}
} else {
err!("failed to print packet: invalid ip header: {:?}", packet);
}
}
/// This function extracts a key from a given IPv4 network buffer list (NBL).
/// The key contains the protocol, local and remote addresses and ports.
///
/// # Arguments
///
/// * `nbl` - A reference to the network buffer list from which the key will be extracted.
/// * `direction` - The direction of the packet (Inbound or Outbound).
///
/// # Returns
///
/// * `Ok(Key)` - A key containing the protocol, local and remote addresses and ports.
/// * `Err(String)` - An error message if the function fails to get net_buffer data.
const HEADERS_LEN: usize = smoltcp::wire::IPV4_HEADER_LEN + smoltcp::wire::TCP_HEADER_LEN;
fn get_ports(packet: &[u8], protocol: smoltcp::wire::IpProtocol) -> (u16, u16) {
match protocol {
smoltcp::wire::IpProtocol::Tcp => {
let tcp_packet = TcpPacket::new_unchecked(packet);
(tcp_packet.src_port(), tcp_packet.dst_port())
}
smoltcp::wire::IpProtocol::Udp => {
let udp_packet = UdpPacket::new_unchecked(packet);
(udp_packet.src_port(), udp_packet.dst_port())
}
_ => (0, 0), // No ports for other protocols
}
}
pub fn get_key_from_nbl_v4(nbl: &NetBufferList, direction: Direction) -> Result<Key, String> {
// Get bytes
let mut headers = [0; HEADERS_LEN];
if nbl.read_bytes(&mut headers).is_err() {
return Err("failed to get net_buffer data".to_string());
}
// Parse packet
let ip_packet = Ipv4Packet::new_unchecked(&headers);
let (src_port, dst_port) = get_ports(
&headers[smoltcp::wire::IPV4_HEADER_LEN..],
ip_packet.next_header(),
);
// Build key
match direction {
Direction::Outbound => Ok(Key {
protocol: ip_packet.next_header(),
local_address: IpAddress::Ipv4(ip_packet.src_addr()),
local_port: src_port,
remote_address: IpAddress::Ipv4(ip_packet.dst_addr()),
remote_port: dst_port,
}),
Direction::Inbound => Ok(Key {
protocol: ip_packet.next_header(),
local_address: IpAddress::Ipv4(ip_packet.dst_addr()),
local_port: dst_port,
remote_address: IpAddress::Ipv4(ip_packet.src_addr()),
remote_port: src_port,
}),
}
}
/// This function extracts a key from a given IPv6 network buffer list (NBL).
/// The key contains the protocol, local and remote addresses and ports.
///
/// # Arguments
///
/// * `nbl` - A reference to the network buffer list from which the key will be extracted.
/// * `direction` - The direction of the packet (Inbound or Outbound).
///
/// # Returns
///
/// * `Ok(Key)` - A key containing the protocol, local and remote addresses and ports.
/// * `Err(String)` - An error message if the function fails to get net_buffer data.
pub fn get_key_from_nbl_v6(nbl: &NetBufferList, direction: Direction) -> Result<Key, String> {
// Get bytes
let mut headers = [0; smoltcp::wire::IPV6_HEADER_LEN + smoltcp::wire::TCP_HEADER_LEN];
let Ok(()) = nbl.read_bytes(&mut headers) else {
return Err("failed to get net_buffer data".to_string());
};
// Parse packet
let ip_packet = Ipv6Packet::new_unchecked(&headers);
let (src_port, dst_port) = get_ports(
&headers[smoltcp::wire::IPV6_HEADER_LEN..],
ip_packet.next_header(),
);
// Build key
match direction {
Direction::Outbound => Ok(Key {
protocol: ip_packet.next_header(),
local_address: IpAddress::Ipv6(ip_packet.src_addr()),
local_port: src_port,
remote_address: IpAddress::Ipv6(ip_packet.dst_addr()),
remote_port: dst_port,
}),
Direction::Inbound => Ok(Key {
protocol: ip_packet.next_header(),
local_address: IpAddress::Ipv6(ip_packet.dst_addr()),
local_port: dst_port,
remote_address: IpAddress::Ipv6(ip_packet.src_addr()),
remote_port: src_port,
}),
}
}
// Converts a given key into connection information.
//
// This function takes a key, packet id, process id, and direction as input.
// It then uses these to create a new `ConnectionInfoV6` or `ConnectionInfoV4` object,
// depending on whether the IP addresses in the key are IPv6 or IPv4 respectively.
//
// # Arguments
//
// * `key` - A reference to the key object containing the connection details.
// * `packet_id` - The id of the packet.
// * `process_id` - The id of the process.
// * `direction` - The direction of the connection.
//
// # Returns
//
// * `Some(Box<dyn Info>)` - A boxed `Info` trait object if the key contains valid IPv4 or IPv6 addresses.
// * `None` - If the key does not contain valid IPv4 or IPv6 addresses.
// pub fn key_to_connection_info(
// key: &Key,
// packet_id: u64,
// process_id: u64,
// direction: Direction,
// payload: &[u8],
// ) -> Option<Info> {
// let (local_port, remote_port) = match key.protocol {
// IpProtocol::Tcp | IpProtocol::Udp => (key.local_port, key.remote_port),
// _ => (0, 0),
// };
// match (key.local_address, key.remote_address) {
// (IpAddress::Ipv6(local_ip), IpAddress::Ipv6(remote_ip)) if key.is_ipv6() => {
// Some(protocol::info::connection_info_v6(
// packet_id,
// process_id,
// direction as u8,
// u8::from(key.protocol),
// local_ip.0,
// remote_ip.0,
// local_port,
// remote_port,
// payload,
// ))
// }
// (IpAddress::Ipv4(local_ip), IpAddress::Ipv4(remote_ip)) => {
// Some(protocol::info::connection_info_v4(
// packet_id,
// process_id,
// direction as u8,
// u8::from(key.protocol),
// local_ip.0,
// remote_ip.0,
// local_port,
// remote_port,
// payload,
// ))
// }
// _ => None,
// }
// }

View File

@@ -0,0 +1,203 @@
use smoltcp::wire::{Ipv4Address, Ipv6Address};
use wdk::filter_engine::{callout_data::CalloutData, layer, net_buffer::NetBufferListIter};
use crate::{bandwidth, connection::Direction};
pub fn stream_layer_tcp_v4(data: CalloutData) {
let Some(device) = crate::entry::get_device() else {
return;
};
let mut direction = Direction::Outbound;
let data_length = if let Some(packet) = data.get_stream_callout_packet() {
if packet.is_receive() {
direction = Direction::Inbound;
}
packet.get_data_len()
} else {
return;
};
type Fields = layer::FieldsStreamV4;
let local_ip = Ipv4Address::from_bytes(
&data
.get_value_u32(Fields::IpLocalAddress as usize)
.to_be_bytes(),
);
let local_port = data.get_value_u16(Fields::IpLocalPort as usize);
let remote_ip = Ipv4Address::from_bytes(
&data
.get_value_u32(Fields::IpRemoteAddress as usize)
.to_be_bytes(),
);
let remote_port = data.get_value_u16(Fields::IpRemotePort as usize);
match direction {
Direction::Outbound => {
device.bandwidth_stats.update_tcp_v4_tx(
bandwidth::Key {
local_ip,
local_port,
remote_ip,
remote_port,
},
data_length,
);
}
Direction::Inbound => {
device.bandwidth_stats.update_tcp_v4_rx(
bandwidth::Key {
local_ip,
local_port,
remote_ip,
remote_port,
},
data_length,
);
}
}
}
pub fn stream_layer_tcp_v6(data: CalloutData) {
let Some(device) = crate::entry::get_device() else {
return;
};
let mut direction = Direction::Outbound;
let data_length = if let Some(packet) = data.get_stream_callout_packet() {
if packet.is_receive() {
direction = Direction::Inbound;
}
packet.get_data_len()
} else {
return;
};
type Fields = layer::FieldsStreamV6;
if data_length == 0 {
return;
}
let local_ip =
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpLocalAddress as usize));
let local_port = data.get_value_u16(Fields::IpLocalPort as usize);
let remote_ip =
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpRemoteAddress as usize));
let remote_port = data.get_value_u16(Fields::IpRemotePort as usize);
match direction {
Direction::Outbound => {
device.bandwidth_stats.update_tcp_v6_tx(
bandwidth::Key {
local_ip,
local_port,
remote_ip,
remote_port,
},
data_length,
);
}
Direction::Inbound => {
device.bandwidth_stats.update_tcp_v6_rx(
bandwidth::Key {
local_ip,
local_port,
remote_ip,
remote_port,
},
data_length,
);
}
}
}
pub fn stream_layer_udp_v4(data: CalloutData) {
let Some(device) = crate::entry::get_device() else {
return;
};
let mut data_length: usize = 0;
for nbl in NetBufferListIter::new(data.get_layer_data() as _) {
data_length += nbl.get_data_length() as usize;
}
type Fields = layer::FieldsDatagramDataV4;
let mut direction = Direction::Inbound;
if data.get_value_u8(Fields::Direction as usize) == 0 {
direction = Direction::Outbound;
}
let local_ip = Ipv4Address::from_bytes(
&data
.get_value_u32(Fields::IpLocalAddress as usize)
.to_be_bytes(),
);
let local_port = data.get_value_u16(Fields::IpLocalPort as usize);
let remote_ip = Ipv4Address::from_bytes(
&data
.get_value_u32(Fields::IpRemoteAddress as usize)
.to_be_bytes(),
);
let remote_port = data.get_value_u16(Fields::IpRemotePort as usize);
match direction {
Direction::Outbound => {
device.bandwidth_stats.update_udp_v4_tx(
bandwidth::Key {
local_ip,
local_port,
remote_ip,
remote_port,
},
data_length,
);
}
Direction::Inbound => {
device.bandwidth_stats.update_udp_v4_rx(
bandwidth::Key {
local_ip,
local_port,
remote_ip,
remote_port,
},
data_length,
);
}
}
}
pub fn stream_layer_udp_v6(data: CalloutData) {
let Some(device) = crate::entry::get_device() else {
return;
};
let mut data_length: usize = 0;
for nbl in NetBufferListIter::new(data.get_layer_data() as _) {
data_length += nbl.get_data_length() as usize;
}
type Fields = layer::FieldsDatagramDataV6;
let mut direction = Direction::Inbound;
if data.get_value_u8(Fields::Direction as usize) == 0 {
direction = Direction::Outbound;
}
let local_ip =
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpLocalAddress as usize));
let local_port = data.get_value_u16(Fields::IpLocalPort as usize);
let remote_ip =
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpRemoteAddress as usize));
let remote_port = data.get_value_u16(Fields::IpRemotePort as usize);
match direction {
Direction::Outbound => {
device.bandwidth_stats.update_udp_v6_tx(
bandwidth::Key {
local_ip,
local_port,
remote_ip,
remote_port,
},
data_length,
);
}
Direction::Inbound => {
device.bandwidth_stats.update_udp_v6_rx(
bandwidth::Key {
local_ip,
local_port,
remote_ip,
remote_port,
},
data_length,
);
}
}
}