Add rust kext to the mono repo
This commit is contained in:
537
windows_kext/driver/src/ale_callouts.rs
Normal file
537
windows_kext/driver/src/ale_callouts.rs
Normal file
@@ -0,0 +1,537 @@
|
||||
use crate::connection::{Connection, ConnectionV4, ConnectionV6, Direction, Verdict};
|
||||
use crate::connection_map::Key;
|
||||
use crate::device::{Device, Packet};
|
||||
|
||||
use crate::info;
|
||||
use smoltcp::wire::{
|
||||
IpAddress, IpProtocol, Ipv4Address, Ipv6Address, IPV4_HEADER_LEN, IPV6_HEADER_LEN,
|
||||
};
|
||||
use wdk::filter_engine::callout_data::CalloutData;
|
||||
use wdk::filter_engine::layer::{
|
||||
self, FieldsAleAuthConnectV4, FieldsAleAuthConnectV6, FieldsAleAuthRecvAcceptV4,
|
||||
FieldsAleAuthRecvAcceptV6, ValueType,
|
||||
};
|
||||
use wdk::filter_engine::net_buffer::NetBufferList;
|
||||
use wdk::filter_engine::packet::{Injector, TransportPacketList};
|
||||
|
||||
// ALE Layers
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
struct AleLayerData {
|
||||
is_ipv6: bool,
|
||||
reauthorize: bool,
|
||||
process_id: u64,
|
||||
protocol: IpProtocol,
|
||||
direction: Direction,
|
||||
local_ip: IpAddress,
|
||||
local_port: u16,
|
||||
remote_ip: IpAddress,
|
||||
remote_port: u16,
|
||||
interface_index: u32,
|
||||
sub_interface_index: u32,
|
||||
}
|
||||
|
||||
impl AleLayerData {
|
||||
fn as_key(&self) -> Key {
|
||||
let mut local_port = 0;
|
||||
let mut remote_port = 0;
|
||||
match self.protocol {
|
||||
IpProtocol::Tcp | IpProtocol::Udp => {
|
||||
local_port = self.local_port;
|
||||
remote_port = self.remote_port;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Key {
|
||||
protocol: self.protocol,
|
||||
local_address: self.local_ip,
|
||||
local_port,
|
||||
remote_address: self.remote_ip,
|
||||
remote_port,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_protocol(data: &CalloutData, index: usize) -> IpProtocol {
|
||||
IpProtocol::from(data.get_value_u8(index))
|
||||
}
|
||||
|
||||
fn get_ipv4_address(data: &CalloutData, index: usize) -> IpAddress {
|
||||
IpAddress::Ipv4(Ipv4Address::from_bytes(
|
||||
&data.get_value_u32(index).to_be_bytes(),
|
||||
))
|
||||
}
|
||||
|
||||
fn get_ipv6_address(data: &CalloutData, index: usize) -> IpAddress {
|
||||
IpAddress::Ipv6(Ipv6Address::from_bytes(data.get_value_byte_array16(index)))
|
||||
}
|
||||
|
||||
pub fn ale_layer_connect_v4(data: CalloutData) {
|
||||
type Fields = FieldsAleAuthConnectV4;
|
||||
let ale_data = AleLayerData {
|
||||
is_ipv6: false,
|
||||
reauthorize: data.is_reauthorize(Fields::Flags as usize),
|
||||
process_id: data.get_process_id().unwrap_or(0),
|
||||
protocol: get_protocol(&data, Fields::IpProtocol as usize),
|
||||
direction: Direction::Outbound,
|
||||
local_ip: get_ipv4_address(&data, Fields::IpLocalAddress as usize),
|
||||
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
remote_ip: get_ipv4_address(&data, Fields::IpRemoteAddress as usize),
|
||||
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
|
||||
interface_index: 0,
|
||||
sub_interface_index: 0,
|
||||
};
|
||||
|
||||
ale_layer_auth(data, ale_data);
|
||||
}
|
||||
|
||||
pub fn ale_layer_accept_v4(data: CalloutData) {
|
||||
type Fields = FieldsAleAuthRecvAcceptV4;
|
||||
let ale_data = AleLayerData {
|
||||
is_ipv6: false,
|
||||
reauthorize: data.is_reauthorize(Fields::Flags as usize),
|
||||
process_id: data.get_process_id().unwrap_or(0),
|
||||
protocol: get_protocol(&data, Fields::IpProtocol as usize),
|
||||
direction: Direction::Inbound,
|
||||
local_ip: get_ipv4_address(&data, Fields::IpLocalAddress as usize),
|
||||
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
remote_ip: get_ipv4_address(&data, Fields::IpRemoteAddress as usize),
|
||||
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
|
||||
interface_index: data.get_value_u32(Fields::InterfaceIndex as usize),
|
||||
sub_interface_index: data.get_value_u32(Fields::SubInterfaceIndex as usize),
|
||||
};
|
||||
ale_layer_auth(data, ale_data);
|
||||
}
|
||||
|
||||
pub fn ale_layer_connect_v6(data: CalloutData) {
|
||||
type Fields = FieldsAleAuthConnectV6;
|
||||
|
||||
let ale_data = AleLayerData {
|
||||
is_ipv6: true,
|
||||
reauthorize: data.is_reauthorize(Fields::Flags as usize),
|
||||
process_id: data.get_process_id().unwrap_or(0),
|
||||
protocol: get_protocol(&data, Fields::IpProtocol as usize),
|
||||
direction: Direction::Outbound,
|
||||
local_ip: get_ipv6_address(&data, Fields::IpLocalAddress as usize),
|
||||
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
remote_ip: get_ipv6_address(&data, Fields::IpRemoteAddress as usize),
|
||||
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
|
||||
interface_index: data.get_value_u32(Fields::InterfaceIndex as usize),
|
||||
sub_interface_index: data.get_value_u32(Fields::SubInterfaceIndex as usize),
|
||||
};
|
||||
|
||||
ale_layer_auth(data, ale_data);
|
||||
}
|
||||
|
||||
pub fn ale_layer_accept_v6(data: CalloutData) {
|
||||
type Fields = FieldsAleAuthRecvAcceptV6;
|
||||
let ale_data = AleLayerData {
|
||||
is_ipv6: true,
|
||||
reauthorize: data.is_reauthorize(Fields::Flags as usize),
|
||||
process_id: data.get_process_id().unwrap_or(0),
|
||||
protocol: get_protocol(&data, Fields::IpProtocol as usize),
|
||||
direction: Direction::Inbound,
|
||||
local_ip: get_ipv6_address(&data, Fields::IpLocalAddress as usize),
|
||||
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
remote_ip: get_ipv6_address(&data, Fields::IpRemoteAddress as usize),
|
||||
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
|
||||
interface_index: data.get_value_u32(Fields::InterfaceIndex as usize),
|
||||
sub_interface_index: data.get_value_u32(Fields::SubInterfaceIndex as usize),
|
||||
};
|
||||
ale_layer_auth(data, ale_data);
|
||||
}
|
||||
|
||||
fn ale_layer_auth(mut data: CalloutData, ale_data: AleLayerData) {
|
||||
let Some(device) = crate::entry::get_device() else {
|
||||
return;
|
||||
};
|
||||
|
||||
match ale_data.protocol {
|
||||
IpProtocol::Tcp | IpProtocol::Udp => {
|
||||
// Only TCP and UDP make sense to be supported in the ALE layer.
|
||||
// Everything else is not associated with a connection and will be handled in the packet layer.
|
||||
}
|
||||
_ => {
|
||||
// Outbound: Will be handled by packet layer next.
|
||||
// Inbound: Was already handled by the packet layer.
|
||||
data.action_permit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let key = ale_data.as_key();
|
||||
|
||||
// Check if connection is already in cache.
|
||||
let verdict = if ale_data.is_ipv6 {
|
||||
device
|
||||
.connection_cache
|
||||
.read_connection_v6(&key, |conn| -> Option<Verdict> {
|
||||
// Function is behind spin lock, just copy and return.
|
||||
Some(conn.verdict)
|
||||
})
|
||||
} else {
|
||||
device
|
||||
.connection_cache
|
||||
.read_connection_v4(&ale_data.as_key(), |conn| -> Option<Verdict> {
|
||||
// Function is behind spin lock, just copy and return.
|
||||
Some(conn.verdict)
|
||||
})
|
||||
};
|
||||
|
||||
// Connection already in cache.
|
||||
if let Some(verdict) = verdict {
|
||||
crate::dbg!("processing existing connection: {} {}", key, verdict);
|
||||
match verdict {
|
||||
// No verdict yet
|
||||
Verdict::Undecided => {
|
||||
crate::dbg!("saving packet: {}", key);
|
||||
// Connection is already pended. Save packet and wait for verdict.
|
||||
match save_packet(device, &mut data, &ale_data, false) {
|
||||
Ok(packet) => {
|
||||
let info = device.packet_cache.push(
|
||||
(key, packet),
|
||||
ale_data.process_id,
|
||||
ale_data.direction,
|
||||
true,
|
||||
);
|
||||
if let Some(info) = info {
|
||||
let _ = device.event_queue.push(info);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
crate::err!("failed to pend packet: {}", err);
|
||||
}
|
||||
};
|
||||
data.block_and_absorb();
|
||||
}
|
||||
// There is a verdict
|
||||
Verdict::PermanentAccept
|
||||
| Verdict::Accept
|
||||
| Verdict::RedirectNameServer
|
||||
| Verdict::RedirectTunnel => {
|
||||
// Continue to packet layer.
|
||||
data.action_permit();
|
||||
}
|
||||
Verdict::PermanentBlock | Verdict::Undeterminable | Verdict::Failed => {
|
||||
// Packet layer will not see this connection.
|
||||
crate::dbg!("permanent block {}", key);
|
||||
data.action_block();
|
||||
}
|
||||
Verdict::PermanentDrop => {
|
||||
// Packet layer will not see this connection.
|
||||
crate::dbg!("permanent drop {}", key);
|
||||
data.block_and_absorb();
|
||||
}
|
||||
Verdict::Block => {
|
||||
if let Direction::Outbound = ale_data.direction {
|
||||
// Handled by packet layer.
|
||||
data.action_permit();
|
||||
} else {
|
||||
// packet layer will still see the packets.
|
||||
data.action_block();
|
||||
}
|
||||
}
|
||||
Verdict::Drop => {
|
||||
if let Direction::Outbound = ale_data.direction {
|
||||
// Handled by packet layer.
|
||||
data.action_permit();
|
||||
} else {
|
||||
// packet layer will still see the packets.
|
||||
data.block_and_absorb();
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
crate::dbg!("pending connection: {} {}", key, ale_data.direction);
|
||||
// Only first packet of a connection can be pended: reauthorize == false
|
||||
let can_pend_connection = !ale_data.reauthorize;
|
||||
match save_packet(device, &mut data, &ale_data, can_pend_connection) {
|
||||
Ok(packet) => {
|
||||
let info = device.packet_cache.push(
|
||||
(key, packet),
|
||||
ale_data.process_id,
|
||||
ale_data.direction,
|
||||
true,
|
||||
);
|
||||
if let Some(info) = info {
|
||||
let _ = device.event_queue.push(info);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
crate::err!("failed to pend packet: {}", err);
|
||||
}
|
||||
};
|
||||
|
||||
// Connection is not in cache, add it.
|
||||
crate::dbg!("adding connection: {} PID: {}", key, ale_data.process_id);
|
||||
if ale_data.is_ipv6 {
|
||||
let conn =
|
||||
ConnectionV6::from_key(&key, ale_data.process_id, ale_data.direction).unwrap();
|
||||
device.connection_cache.add_connection_v6(conn);
|
||||
} else {
|
||||
let conn =
|
||||
ConnectionV4::from_key(&key, ale_data.process_id, ale_data.direction).unwrap();
|
||||
device.connection_cache.add_connection_v4(conn);
|
||||
}
|
||||
|
||||
// Drop packet. It will be re-injected after Portmaster returns a verdict.
|
||||
data.block_and_absorb();
|
||||
}
|
||||
}
|
||||
|
||||
fn save_packet(
|
||||
device: &Device,
|
||||
callout_data: &mut CalloutData,
|
||||
ale_data: &AleLayerData,
|
||||
pend: bool,
|
||||
) -> Result<Packet, alloc::string::String> {
|
||||
let mut packet_list = None;
|
||||
let mut save_packet_list = true;
|
||||
match ale_data.protocol {
|
||||
IpProtocol::Tcp => {
|
||||
if let Direction::Outbound = ale_data.direction {
|
||||
// Only time a packet data is missing is during connect state of outbound TCP connection.
|
||||
// Don't save packet list only if connection is outbound, reauthorize is false and the protocol is TCP.
|
||||
save_packet_list = ale_data.reauthorize;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
if save_packet_list {
|
||||
packet_list = create_packet_list(device, callout_data, ale_data);
|
||||
}
|
||||
if pend && matches!(ale_data.protocol, IpProtocol::Tcp | IpProtocol::Udp) {
|
||||
match callout_data.pend_operation(packet_list) {
|
||||
Ok(classify_defer) => Ok(Packet::AleLayer(classify_defer)),
|
||||
Err(err) => Err(alloc::format!("failed to defer connection: {}", err)),
|
||||
}
|
||||
} else {
|
||||
Ok(Packet::AleLayer(callout_data.pend_filter_rest(packet_list)))
|
||||
}
|
||||
}
|
||||
|
||||
fn create_packet_list(
|
||||
device: &Device,
|
||||
callout_data: &mut CalloutData,
|
||||
ale_data: &AleLayerData,
|
||||
) -> Option<TransportPacketList> {
|
||||
let mut nbl = NetBufferList::new(callout_data.get_layer_data() as _);
|
||||
let mut inbound = false;
|
||||
if let Direction::Inbound = ale_data.direction {
|
||||
if ale_data.is_ipv6 {
|
||||
nbl.retreat(IPV6_HEADER_LEN as u32, true);
|
||||
} else {
|
||||
nbl.retreat(IPV4_HEADER_LEN as u32, true);
|
||||
}
|
||||
inbound = true;
|
||||
}
|
||||
|
||||
let address: &[u8] = match &ale_data.remote_ip {
|
||||
IpAddress::Ipv4(address) => &address.0,
|
||||
IpAddress::Ipv6(address) => &address.0,
|
||||
};
|
||||
if let Ok(clone) = nbl.clone(&device.network_allocator) {
|
||||
return Some(Injector::from_ale_callout(
|
||||
ale_data.is_ipv6,
|
||||
callout_data,
|
||||
clone,
|
||||
address,
|
||||
inbound,
|
||||
ale_data.interface_index,
|
||||
ale_data.sub_interface_index,
|
||||
));
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
pub fn endpoint_closure_v4(data: CalloutData) {
|
||||
type Fields = layer::FieldsAleEndpointClosureV4;
|
||||
let Some(device) = crate::entry::get_device() else {
|
||||
return;
|
||||
};
|
||||
let ip_address_type = data.get_value_type(Fields::IpLocalAddress as usize);
|
||||
if let ValueType::FwpUint32 = ip_address_type {
|
||||
let key = Key {
|
||||
protocol: get_protocol(&data, Fields::IpProtocol as usize),
|
||||
local_address: get_ipv4_address(&data, Fields::IpLocalAddress as usize),
|
||||
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
remote_address: get_ipv4_address(&data, Fields::IpRemoteAddress as usize),
|
||||
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
|
||||
};
|
||||
|
||||
let conn = device.connection_cache.end_connection_v4(key);
|
||||
if let Some(conn) = conn {
|
||||
let info = protocol::info::connection_end_event_v4_info(
|
||||
data.get_process_id().unwrap_or(0),
|
||||
conn.get_direction() as u8,
|
||||
u8::from(get_protocol(&data, Fields::IpProtocol as usize)),
|
||||
conn.local_address.0,
|
||||
conn.remote_address.0,
|
||||
conn.local_port,
|
||||
conn.remote_port,
|
||||
);
|
||||
let _ = device.event_queue.push(info);
|
||||
}
|
||||
} else {
|
||||
// Invalid ip address type. Just ignore the error.
|
||||
// err!(
|
||||
// device.logger,
|
||||
// "unknown ipv4 address type: {:?}",
|
||||
// ip_address_type
|
||||
// );
|
||||
}
|
||||
}
|
||||
|
||||
pub fn endpoint_closure_v6(data: CalloutData) {
|
||||
type Fields = layer::FieldsAleEndpointClosureV6;
|
||||
let Some(device) = crate::entry::get_device() else {
|
||||
return;
|
||||
};
|
||||
let local_ip_address_type = data.get_value_type(Fields::IpLocalAddress as usize);
|
||||
let remote_ip_address_type = data.get_value_type(Fields::IpRemoteAddress as usize);
|
||||
|
||||
if let ValueType::FwpByteArray16Type = local_ip_address_type {
|
||||
if let ValueType::FwpByteArray16Type = remote_ip_address_type {
|
||||
let key = Key {
|
||||
protocol: get_protocol(&data, Fields::IpProtocol as usize),
|
||||
local_address: get_ipv6_address(&data, Fields::IpLocalAddress as usize),
|
||||
local_port: data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
remote_address: get_ipv6_address(&data, Fields::IpRemoteAddress as usize),
|
||||
remote_port: data.get_value_u16(Fields::IpRemotePort as usize),
|
||||
};
|
||||
|
||||
let conn = device.connection_cache.end_connection_v6(key);
|
||||
if let Some(conn) = conn {
|
||||
let info = protocol::info::connection_end_event_v6_info(
|
||||
data.get_process_id().unwrap_or(0),
|
||||
conn.get_direction() as u8,
|
||||
u8::from(get_protocol(&data, Fields::IpProtocol as usize)),
|
||||
conn.local_address.0,
|
||||
conn.remote_address.0,
|
||||
conn.local_port,
|
||||
conn.remote_port,
|
||||
);
|
||||
let _ = device.event_queue.push(info);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ale_resource_monitor(data: CalloutData) {
|
||||
let Some(device) = crate::entry::get_device() else {
|
||||
return;
|
||||
};
|
||||
match data.layer {
|
||||
layer::Layer::AleResourceAssignmentV4Discard => {
|
||||
type Fields = layer::FieldsAleResourceAssignmentV4;
|
||||
if let Some(conns) = device.connection_cache.end_all_on_port_v4((
|
||||
get_protocol(&data, Fields::IpProtocol as usize),
|
||||
data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
)) {
|
||||
let process_id = data.get_process_id().unwrap_or(0);
|
||||
info!(
|
||||
"Port {}/{} Ipv4 assign request discarded pid={}",
|
||||
data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
get_protocol(&data, Fields::IpProtocol as usize),
|
||||
process_id,
|
||||
);
|
||||
for conn in conns {
|
||||
let info = protocol::info::connection_end_event_v4_info(
|
||||
process_id,
|
||||
conn.get_direction() as u8,
|
||||
data.get_value_u8(Fields::IpProtocol as usize),
|
||||
conn.local_address.0,
|
||||
conn.remote_address.0,
|
||||
conn.local_port,
|
||||
conn.remote_port,
|
||||
);
|
||||
let _ = device.event_queue.push(info);
|
||||
}
|
||||
}
|
||||
}
|
||||
layer::Layer::AleResourceAssignmentV6Discard => {
|
||||
type Fields = layer::FieldsAleResourceAssignmentV6;
|
||||
if let Some(conns) = device.connection_cache.end_all_on_port_v6((
|
||||
get_protocol(&data, Fields::IpProtocol as usize),
|
||||
data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
)) {
|
||||
let process_id = data.get_process_id().unwrap_or(0);
|
||||
info!(
|
||||
"Port {}/{} Ipv6 assign request discarded pid={}",
|
||||
data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
get_protocol(&data, Fields::IpProtocol as usize),
|
||||
process_id,
|
||||
);
|
||||
for conn in conns {
|
||||
let info = protocol::info::connection_end_event_v6_info(
|
||||
process_id,
|
||||
conn.get_direction() as u8,
|
||||
data.get_value_u8(Fields::IpProtocol as usize),
|
||||
conn.local_address.0,
|
||||
conn.remote_address.0,
|
||||
conn.local_port,
|
||||
conn.remote_port,
|
||||
);
|
||||
let _ = device.event_queue.push(info);
|
||||
}
|
||||
}
|
||||
}
|
||||
layer::Layer::AleResourceReleaseV4 => {
|
||||
type Fields = layer::FieldsAleResourceReleaseV4;
|
||||
if let Some(conns) = device.connection_cache.end_all_on_port_v4((
|
||||
get_protocol(&data, Fields::IpProtocol as usize),
|
||||
data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
)) {
|
||||
let process_id = data.get_process_id().unwrap_or(0);
|
||||
info!(
|
||||
"Port {}/{} released pid={}",
|
||||
data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
get_protocol(&data, Fields::IpProtocol as usize),
|
||||
process_id,
|
||||
);
|
||||
for conn in conns {
|
||||
let info = protocol::info::connection_end_event_v4_info(
|
||||
process_id,
|
||||
conn.get_direction() as u8,
|
||||
data.get_value_u8(Fields::IpProtocol as usize),
|
||||
conn.local_address.0,
|
||||
conn.remote_address.0,
|
||||
conn.local_port,
|
||||
conn.remote_port,
|
||||
);
|
||||
let _ = device.event_queue.push(info);
|
||||
}
|
||||
}
|
||||
}
|
||||
layer::Layer::AleResourceReleaseV6 => {
|
||||
type Fields = layer::FieldsAleResourceReleaseV6;
|
||||
if let Some(conns) = device.connection_cache.end_all_on_port_v6((
|
||||
get_protocol(&data, Fields::IpProtocol as usize),
|
||||
data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
)) {
|
||||
let process_id = data.get_process_id().unwrap_or(0);
|
||||
info!(
|
||||
"Port {}/{} released pid={}",
|
||||
data.get_value_u16(Fields::IpLocalPort as usize),
|
||||
get_protocol(&data, Fields::IpProtocol as usize),
|
||||
process_id,
|
||||
);
|
||||
for conn in conns {
|
||||
let info = protocol::info::connection_end_event_v6_info(
|
||||
process_id,
|
||||
conn.get_direction() as u8,
|
||||
data.get_value_u8(Fields::IpProtocol as usize),
|
||||
conn.local_address.0,
|
||||
conn.remote_address.0,
|
||||
conn.local_port,
|
||||
conn.remote_port,
|
||||
);
|
||||
let _ = device.event_queue.push(info);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
25
windows_kext/driver/src/array_holder.rs
Normal file
25
windows_kext/driver/src/array_holder.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use core::cell::RefCell;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
|
||||
pub struct ArrayHolder(RefCell<Option<Vec<u8>>>);
|
||||
unsafe impl Sync for ArrayHolder {}
|
||||
|
||||
impl ArrayHolder {
|
||||
pub const fn default() -> Self {
|
||||
Self(RefCell::new(None))
|
||||
}
|
||||
|
||||
pub fn save(&self, data: &[u8]) {
|
||||
if let Ok(mut opt) = self.0.try_borrow_mut() {
|
||||
opt.replace(data.to_vec());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load(&self) -> Option<Vec<u8>> {
|
||||
if let Ok(mut opt) = self.0.try_borrow_mut() {
|
||||
return opt.take();
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
293
windows_kext/driver/src/bandwidth.rs
Normal file
293
windows_kext/driver/src/bandwidth.rs
Normal file
@@ -0,0 +1,293 @@
|
||||
use protocol::info::{BandwidthValueV4, BandwidthValueV6, Info};
|
||||
use smoltcp::wire::{IpProtocol, Ipv4Address, Ipv6Address};
|
||||
use wdk::rw_spin_lock::RwSpinLock;
|
||||
|
||||
use crate::driver_hashmap::DeviceHashMap;
|
||||
|
||||
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
|
||||
pub struct Key<Address>
|
||||
where
|
||||
Address: Eq + PartialEq,
|
||||
{
|
||||
pub local_ip: Address,
|
||||
pub local_port: u16,
|
||||
pub remote_ip: Address,
|
||||
pub remote_port: u16,
|
||||
}
|
||||
|
||||
struct Value {
|
||||
received_bytes: usize,
|
||||
transmitted_bytes: usize,
|
||||
}
|
||||
|
||||
enum Direction {
|
||||
Tx(usize),
|
||||
Rx(usize),
|
||||
}
|
||||
pub struct Bandwidth {
|
||||
stats_tcp_v4: DeviceHashMap<Key<Ipv4Address>, Value>,
|
||||
stats_tcp_v4_lock: RwSpinLock,
|
||||
|
||||
stats_tcp_v6: DeviceHashMap<Key<Ipv6Address>, Value>,
|
||||
stats_tcp_v6_lock: RwSpinLock,
|
||||
|
||||
stats_udp_v4: DeviceHashMap<Key<Ipv4Address>, Value>,
|
||||
stats_udp_v4_lock: RwSpinLock,
|
||||
|
||||
stats_udp_v6: DeviceHashMap<Key<Ipv6Address>, Value>,
|
||||
stats_udp_v6_lock: RwSpinLock,
|
||||
}
|
||||
|
||||
impl Bandwidth {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
stats_tcp_v4: DeviceHashMap::new(),
|
||||
stats_tcp_v4_lock: RwSpinLock::default(),
|
||||
|
||||
stats_tcp_v6: DeviceHashMap::new(),
|
||||
stats_tcp_v6_lock: RwSpinLock::default(),
|
||||
|
||||
stats_udp_v4: DeviceHashMap::new(),
|
||||
stats_udp_v4_lock: RwSpinLock::default(),
|
||||
|
||||
stats_udp_v6: DeviceHashMap::new(),
|
||||
stats_udp_v6_lock: RwSpinLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_all_updates_tcp_v4(&mut self) -> Option<Info> {
|
||||
let stats_map;
|
||||
{
|
||||
let _guard = self.stats_tcp_v4_lock.write_lock();
|
||||
if self.stats_tcp_v4.is_empty() {
|
||||
return None;
|
||||
}
|
||||
stats_map = core::mem::replace(&mut self.stats_tcp_v4, DeviceHashMap::new());
|
||||
}
|
||||
|
||||
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
|
||||
for (key, value) in stats_map.iter() {
|
||||
values.push(BandwidthValueV4 {
|
||||
local_ip: key.local_ip.0,
|
||||
local_port: key.local_port,
|
||||
remote_ip: key.remote_ip.0,
|
||||
remote_port: key.remote_port,
|
||||
transmitted_bytes: value.transmitted_bytes as u64,
|
||||
received_bytes: value.received_bytes as u64,
|
||||
});
|
||||
}
|
||||
Some(protocol::info::bandiwth_stats_array_v4(
|
||||
u8::from(IpProtocol::Tcp),
|
||||
values,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn get_all_updates_tcp_v6(&mut self) -> Option<Info> {
|
||||
let stats_map;
|
||||
{
|
||||
let _guard = self.stats_tcp_v6_lock.write_lock();
|
||||
if self.stats_tcp_v6.is_empty() {
|
||||
return None;
|
||||
}
|
||||
stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new());
|
||||
}
|
||||
|
||||
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
|
||||
for (key, value) in stats_map.iter() {
|
||||
values.push(BandwidthValueV6 {
|
||||
local_ip: key.local_ip.0,
|
||||
local_port: key.local_port,
|
||||
remote_ip: key.remote_ip.0,
|
||||
remote_port: key.remote_port,
|
||||
transmitted_bytes: value.transmitted_bytes as u64,
|
||||
received_bytes: value.received_bytes as u64,
|
||||
});
|
||||
}
|
||||
Some(protocol::info::bandiwth_stats_array_v6(
|
||||
u8::from(IpProtocol::Tcp),
|
||||
values,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn get_all_updates_udp_v4(&mut self) -> Option<Info> {
|
||||
let stats_map;
|
||||
{
|
||||
let _guard = self.stats_udp_v4_lock.write_lock();
|
||||
if self.stats_udp_v4.is_empty() {
|
||||
return None;
|
||||
}
|
||||
stats_map = core::mem::replace(&mut self.stats_udp_v4, DeviceHashMap::new());
|
||||
}
|
||||
|
||||
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
|
||||
for (key, value) in stats_map.iter() {
|
||||
values.push(BandwidthValueV4 {
|
||||
local_ip: key.local_ip.0,
|
||||
local_port: key.local_port,
|
||||
remote_ip: key.remote_ip.0,
|
||||
remote_port: key.remote_port,
|
||||
transmitted_bytes: value.transmitted_bytes as u64,
|
||||
received_bytes: value.received_bytes as u64,
|
||||
});
|
||||
}
|
||||
Some(protocol::info::bandiwth_stats_array_v4(
|
||||
u8::from(IpProtocol::Udp),
|
||||
values,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn get_all_updates_udp_v6(&mut self) -> Option<Info> {
|
||||
let stats_map;
|
||||
{
|
||||
let _guard = self.stats_udp_v6_lock.write_lock();
|
||||
if self.stats_tcp_v6.is_empty() {
|
||||
return None;
|
||||
}
|
||||
stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new());
|
||||
}
|
||||
|
||||
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
|
||||
for (key, value) in stats_map.iter() {
|
||||
values.push(BandwidthValueV6 {
|
||||
local_ip: key.local_ip.0,
|
||||
local_port: key.local_port,
|
||||
remote_ip: key.remote_ip.0,
|
||||
remote_port: key.remote_port,
|
||||
transmitted_bytes: value.transmitted_bytes as u64,
|
||||
received_bytes: value.received_bytes as u64,
|
||||
});
|
||||
}
|
||||
Some(protocol::info::bandiwth_stats_array_v6(
|
||||
u8::from(IpProtocol::Udp),
|
||||
values,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn update_tcp_v4_tx(&mut self, key: Key<Ipv4Address>, tx_bytes: usize) {
|
||||
Self::update(
|
||||
&mut self.stats_tcp_v4,
|
||||
&mut self.stats_tcp_v4_lock,
|
||||
key,
|
||||
Direction::Tx(tx_bytes),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn update_tcp_v4_rx(&mut self, key: Key<Ipv4Address>, rx_bytes: usize) {
|
||||
Self::update(
|
||||
&mut self.stats_tcp_v4,
|
||||
&mut self.stats_tcp_v4_lock,
|
||||
key,
|
||||
Direction::Rx(rx_bytes),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn update_tcp_v6_tx(&mut self, key: Key<Ipv6Address>, tx_bytes: usize) {
|
||||
Self::update(
|
||||
&mut self.stats_tcp_v6,
|
||||
&mut self.stats_tcp_v6_lock,
|
||||
key,
|
||||
Direction::Tx(tx_bytes),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn update_tcp_v6_rx(&mut self, key: Key<Ipv6Address>, rx_bytes: usize) {
|
||||
Self::update(
|
||||
&mut self.stats_tcp_v6,
|
||||
&mut self.stats_tcp_v6_lock,
|
||||
key,
|
||||
Direction::Rx(rx_bytes),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn update_udp_v4_tx(&mut self, key: Key<Ipv4Address>, tx_bytes: usize) {
|
||||
Self::update(
|
||||
&mut self.stats_udp_v4,
|
||||
&mut self.stats_udp_v4_lock,
|
||||
key,
|
||||
Direction::Tx(tx_bytes),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn update_udp_v4_rx(&mut self, key: Key<Ipv4Address>, rx_bytes: usize) {
|
||||
Self::update(
|
||||
&mut self.stats_udp_v4,
|
||||
&mut self.stats_udp_v4_lock,
|
||||
key,
|
||||
Direction::Rx(rx_bytes),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn update_udp_v6_tx(&mut self, key: Key<Ipv6Address>, tx_bytes: usize) {
|
||||
Self::update(
|
||||
&mut self.stats_udp_v6,
|
||||
&mut self.stats_udp_v6_lock,
|
||||
key,
|
||||
Direction::Tx(tx_bytes),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn update_udp_v6_rx(&mut self, key: Key<Ipv6Address>, rx_bytes: usize) {
|
||||
Self::update(
|
||||
&mut self.stats_udp_v6,
|
||||
&mut self.stats_udp_v6_lock,
|
||||
key,
|
||||
Direction::Rx(rx_bytes),
|
||||
);
|
||||
}
|
||||
|
||||
fn update<Address: Eq + PartialEq + core::hash::Hash>(
|
||||
map: &mut DeviceHashMap<Key<Address>, Value>,
|
||||
lock: &mut RwSpinLock,
|
||||
key: Key<Address>,
|
||||
bytes: Direction,
|
||||
) {
|
||||
let _guard = lock.write_lock();
|
||||
if let Some(value) = map.get_mut(&key) {
|
||||
match bytes {
|
||||
Direction::Tx(bytes_count) => value.transmitted_bytes += bytes_count,
|
||||
Direction::Rx(bytes_count) => value.received_bytes += bytes_count,
|
||||
}
|
||||
} else {
|
||||
let mut received_bytes = 0;
|
||||
let mut transmitted_bytes = 0;
|
||||
match bytes {
|
||||
Direction::Tx(bytes_count) => transmitted_bytes += bytes_count,
|
||||
Direction::Rx(bytes_count) => received_bytes += bytes_count,
|
||||
}
|
||||
map.insert(
|
||||
key,
|
||||
Value {
|
||||
received_bytes,
|
||||
transmitted_bytes,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn get_entries_count(&self) -> usize {
|
||||
let mut size = 0;
|
||||
{
|
||||
let values = &self.stats_tcp_v4.values();
|
||||
let _guard = self.stats_tcp_v4_lock.read_lock();
|
||||
size += values.len();
|
||||
}
|
||||
{
|
||||
let values = &self.stats_tcp_v6.values();
|
||||
let _guard = self.stats_tcp_v6_lock.read_lock();
|
||||
size += values.len();
|
||||
}
|
||||
{
|
||||
let values = &self.stats_udp_v4.values();
|
||||
let _guard = self.stats_udp_v4_lock.read_lock();
|
||||
size += values.len();
|
||||
}
|
||||
{
|
||||
let values = &self.stats_udp_v6.values();
|
||||
let _guard = self.stats_udp_v6_lock.read_lock();
|
||||
size += values.len();
|
||||
}
|
||||
|
||||
return size;
|
||||
}
|
||||
}
|
||||
185
windows_kext/driver/src/callouts.rs
Normal file
185
windows_kext/driver/src/callouts.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
use alloc::vec::Vec;
|
||||
use wdk::filter_engine::callout::FilterType;
|
||||
use wdk::{
|
||||
consts,
|
||||
filter_engine::{callout::Callout, layer::Layer},
|
||||
};
|
||||
|
||||
use crate::{ale_callouts, packet_callouts, stream_callouts};
|
||||
|
||||
pub fn get_callout_vec() -> Vec<Callout> {
|
||||
alloc::vec![
|
||||
// -----------------------------------------
|
||||
// ALE Auth layers
|
||||
Callout::new(
|
||||
"AleLayerOutboundV4",
|
||||
"ALE layer for outbound connection for ipv4",
|
||||
0x58545073_f893_454c_bbea_a57bc964f46d,
|
||||
Layer::AleAuthConnectV4,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
FilterType::Resettable,
|
||||
ale_callouts::ale_layer_connect_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"AleLayerInboundV4",
|
||||
"ALE layer for inbound connections for ipv4",
|
||||
0xc6021395_0724_4e2c_ae20_3dde51fc3c68,
|
||||
Layer::AleAuthRecvAcceptV4,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
FilterType::Resettable,
|
||||
ale_callouts::ale_layer_accept_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"AleLayerOutboundV6",
|
||||
"ALE layer for outbound connections for ipv6",
|
||||
0x4bd2a080_2585_478d_977c_7f340c6bc3d4,
|
||||
Layer::AleAuthConnectV6,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
FilterType::Resettable,
|
||||
ale_callouts::ale_layer_connect_v6,
|
||||
),
|
||||
Callout::new(
|
||||
"AleLayerInboundV6",
|
||||
"ALE layer for inbound connections for ipv6",
|
||||
0xd24480da_38fa_4099_9383_b5c83b69e4f2,
|
||||
Layer::AleAuthRecvAcceptV6,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
FilterType::Resettable,
|
||||
ale_callouts::ale_layer_accept_v6,
|
||||
),
|
||||
// -----------------------------------------
|
||||
// ALE connection end layers
|
||||
Callout::new(
|
||||
"AleEndpointClosureV4",
|
||||
"ALE layer for indicating closing of connection for ipv4",
|
||||
0x58f02845_ace9_4455_ac80_8a84b86fe566,
|
||||
Layer::AleEndpointClosureV4,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
FilterType::NonResettable,
|
||||
ale_callouts::endpoint_closure_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"AleEndpointClosureV6",
|
||||
"ALE layer for indicating closing of connection for ipv6",
|
||||
0x2bc82359_9dc5_4315_9c93_c89467e283ce,
|
||||
Layer::AleEndpointClosureV6,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
FilterType::NonResettable,
|
||||
ale_callouts::endpoint_closure_v6,
|
||||
),
|
||||
// -----------------------------------------
|
||||
// ALE resource assignment and release.
|
||||
// Callout::new(
|
||||
// "AleResourceAssignmentV4",
|
||||
// "Ipv4 Port assignment monitoring",
|
||||
// 0x6b9d1985_6f75_4d05_b9b5_1607e187906f,
|
||||
// Layer::AleResourceAssignmentV4Discard,
|
||||
// consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
// FilterType::NonResettable,
|
||||
// ale_callouts::ale_resource_monitor,
|
||||
// ),
|
||||
Callout::new(
|
||||
"AleResourceReleaseV4",
|
||||
"Ipv4 Port release monitor",
|
||||
0x7b513bb3_a0be_4f77_a4bc_03c052abe8d7,
|
||||
Layer::AleResourceReleaseV4,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
FilterType::NonResettable,
|
||||
ale_callouts::ale_resource_monitor,
|
||||
),
|
||||
// Callout::new(
|
||||
// "AleResourceAssignmentV6",
|
||||
// "Ipv4 Port assignment monitor",
|
||||
// 0xb0d02299_3d3e_437d_916a_f0e96a60cc18,
|
||||
// Layer::AleResourceAssignmentV6Discard,
|
||||
// consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
// FilterType::NonResettable,
|
||||
// ale_callouts::ale_resource_monitor,
|
||||
// ),
|
||||
Callout::new(
|
||||
"AleResourceReleaseV6",
|
||||
"Ipv6 Port release monitor",
|
||||
0x6cf36e04_e656_42c3_8cac_a1ce05328bd1,
|
||||
Layer::AleResourceReleaseV6,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
FilterType::NonResettable,
|
||||
ale_callouts::ale_resource_monitor,
|
||||
),
|
||||
// -----------------------------------------
|
||||
// Stream layer
|
||||
Callout::new(
|
||||
"StreamLayerV4",
|
||||
"Stream layer for ipv4",
|
||||
0xe2ca13bf_9710_4caa_a45c_e8c78b5ac780,
|
||||
Layer::StreamV4,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
FilterType::NonResettable,
|
||||
stream_callouts::stream_layer_tcp_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"StreamLayerV6",
|
||||
"Stream layer for ipv6",
|
||||
0x66c549b3_11e2_4b27_8f73_856e6fd82baa,
|
||||
Layer::StreamV6,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
FilterType::NonResettable,
|
||||
stream_callouts::stream_layer_tcp_v6,
|
||||
),
|
||||
Callout::new(
|
||||
"DatagramDataLayerV4",
|
||||
"DatagramData layer for ipv4",
|
||||
0xe7eeeaba_168a_45bb_8747_e1a702feb2c5,
|
||||
Layer::DatagramDataV4,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
FilterType::NonResettable,
|
||||
stream_callouts::stream_layer_udp_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"DatagramDataLayerV6",
|
||||
"DatagramData layer for ipv4",
|
||||
0xb25862cd_f744_4452_b14a_d0c1e5a25b30,
|
||||
Layer::DatagramDataV6,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
FilterType::NonResettable,
|
||||
stream_callouts::stream_layer_udp_v6,
|
||||
),
|
||||
// -----------------------------------------
|
||||
// Packet layers
|
||||
Callout::new(
|
||||
"IPPacketOutboundV4",
|
||||
"IP packet outbound network layer callout for Ipv4",
|
||||
0xf3183afe_dc35_49f1_8ea2_b16b5666dd36,
|
||||
Layer::OutboundIppacketV4,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
FilterType::NonResettable,
|
||||
packet_callouts::ip_packet_layer_outbound_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"IPPacketInboundV4",
|
||||
"IP packet inbound network layer callout for Ipv4",
|
||||
0xf0369374_203d_4bf0_83d2_b2ad3cc17a50,
|
||||
Layer::InboundIppacketV4,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
FilterType::NonResettable,
|
||||
packet_callouts::ip_packet_layer_inbound_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"IPPacketOutboundV6",
|
||||
"IP packet outbound network layer callout for Ipv6",
|
||||
0x91daf8bc_0908_4bf8_9f81_2c538ab8f25a,
|
||||
Layer::OutboundIppacketV6,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
FilterType::NonResettable,
|
||||
packet_callouts::ip_packet_layer_outbound_v6,
|
||||
),
|
||||
Callout::new(
|
||||
"IPPacketInboundV6",
|
||||
"IP packet inbound network layer callout for Ipv6",
|
||||
0xfe9faf5f_ceb2_4cd9_9995_f2f2b4f5fcc0,
|
||||
Layer::InboundIppacketV6,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
FilterType::NonResettable,
|
||||
packet_callouts::ip_packet_layer_inbound_v6,
|
||||
)
|
||||
]
|
||||
}
|
||||
59
windows_kext/driver/src/common.rs
Normal file
59
windows_kext/driver/src/common.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use core::fmt::Display;
|
||||
|
||||
use num_derive::{FromPrimitive, ToPrimitive};
|
||||
|
||||
pub const ICMPV4_CODE_DESTINATION_UNREACHABLE: u32 = 3;
|
||||
pub const ICMPV4_CODE_DU_PORT_UNREACHABLE: u32 = 3; // Destination Unreachable (Port unreachable) ;
|
||||
pub const ICMPV4_CODE_DU_ADMINISTRATIVELY_PROHIBITED: u32 = 13; // Destination Unreachable (Communication Administratively Prohibited) ;
|
||||
|
||||
pub const ICMPV6_CODE_DESTINATION_UNREACHABLE: u32 = 1;
|
||||
pub const ICMPV6_CODE_DU_PORT_UNREACHABLE: u32 = 4; // Destination Unreachable (Port unreachable) ;
|
||||
|
||||
enum Direction {
|
||||
Outbound = 0,
|
||||
Inbound = 1,
|
||||
}
|
||||
|
||||
const SIOCTL_TYPE: u32 = 40000;
|
||||
macro_rules! ctl_code {
|
||||
($device_type:expr, $function:expr, $method:expr, $access:expr) => {
|
||||
($device_type << 16) | ($access << 14) | ($function << 2) | $method
|
||||
};
|
||||
}
|
||||
|
||||
pub const METHOD_BUFFERED: u32 = 0;
|
||||
pub const METHOD_IN_DIRECT: u32 = 1;
|
||||
pub const METHOD_OUT_DIRECT: u32 = 2;
|
||||
pub const METHOD_NEITHER: u32 = 3;
|
||||
|
||||
pub const FILE_READ_DATA: u32 = 0x0001; // file & pipe
|
||||
pub const FILE_WRITE_DATA: u32 = 0x0002; // file & pipe
|
||||
|
||||
#[repr(u32)]
|
||||
#[derive(FromPrimitive, ToPrimitive)]
|
||||
pub enum ControlCode {
|
||||
Version = ctl_code!(
|
||||
SIOCTL_TYPE,
|
||||
0x800,
|
||||
METHOD_BUFFERED,
|
||||
FILE_READ_DATA | FILE_WRITE_DATA
|
||||
),
|
||||
ShutdownRequest = ctl_code!(
|
||||
SIOCTL_TYPE,
|
||||
0x801,
|
||||
METHOD_BUFFERED,
|
||||
FILE_READ_DATA | FILE_WRITE_DATA
|
||||
),
|
||||
}
|
||||
|
||||
impl Display for ControlCode {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
match self {
|
||||
ControlCode::Version => _ = write!(f, "Version"),
|
||||
ControlCode::ShutdownRequest => _ = write!(f, "Shutdown"),
|
||||
};
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
499
windows_kext/driver/src/connection.rs
Normal file
499
windows_kext/driver/src/connection.rs
Normal file
@@ -0,0 +1,499 @@
|
||||
use alloc::{
|
||||
boxed::Box,
|
||||
string::{String, ToString},
|
||||
};
|
||||
use core::{
|
||||
fmt::{Debug, Display},
|
||||
sync::atomic::{AtomicU64, Ordering},
|
||||
};
|
||||
use num_derive::FromPrimitive;
|
||||
use smoltcp::wire::{IpAddress, IpProtocol, Ipv4Address, Ipv6Address};
|
||||
|
||||
use crate::connection_map::Key;
|
||||
|
||||
pub static PM_DNS_PORT: u16 = 53;
|
||||
pub static PM_SPN_PORT: u16 = 717;
|
||||
|
||||
// Make sure this in sync with the Go version
|
||||
#[derive(Copy, Clone, FromPrimitive)]
|
||||
#[repr(u8)]
|
||||
#[rustfmt::skip]
|
||||
pub enum Verdict {
|
||||
Undecided = 0, // Undecided is the default status of new connections.
|
||||
Undeterminable = 1,
|
||||
Accept = 2,
|
||||
PermanentAccept = 3,
|
||||
Block = 4,
|
||||
PermanentBlock = 5,
|
||||
Drop = 6,
|
||||
PermanentDrop = 7,
|
||||
RedirectNameServer = 8,
|
||||
RedirectTunnel = 9,
|
||||
Failed = 10,
|
||||
}
|
||||
|
||||
impl Display for Verdict {
|
||||
#[rustfmt::skip]
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
match self {
|
||||
Verdict::Undecided => write!(f, "Undecided"),
|
||||
Verdict::Undeterminable => write!(f, "Undeterminable"),
|
||||
Verdict::Accept => write!(f, "Accept"),
|
||||
Verdict::PermanentAccept => write!(f, "PermanentAccept"),
|
||||
Verdict::Block => write!(f, "Block"),
|
||||
Verdict::PermanentBlock => write!(f, "PermanentBlock"),
|
||||
Verdict::Drop => write!(f, "Drop"),
|
||||
Verdict::PermanentDrop => write!(f, "PermanentDrop"),
|
||||
Verdict::RedirectNameServer => write!(f, "RedirectNameServer"),
|
||||
Verdict::RedirectTunnel => write!(f, "RedirectTunnel"),
|
||||
Verdict::Failed => write!(f, "Failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl Verdict {
|
||||
/// Returns true if the verdict is a redirect.
|
||||
pub fn is_redirect(&self) -> bool {
|
||||
matches!(self, Verdict::RedirectNameServer | Verdict::RedirectTunnel)
|
||||
}
|
||||
|
||||
/// Returns true if the verdict is a permanent verdict.
|
||||
pub fn is_permanent(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Verdict::PermanentAccept
|
||||
| Verdict::PermanentBlock
|
||||
| Verdict::PermanentDrop
|
||||
| Verdict::RedirectNameServer
|
||||
| Verdict::RedirectTunnel
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Direction of the connection.
|
||||
#[derive(Copy, Clone, FromPrimitive)]
|
||||
#[repr(u8)]
|
||||
pub enum Direction {
|
||||
Outbound = 0,
|
||||
Inbound = 1,
|
||||
}
|
||||
|
||||
impl Display for Direction {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
match self {
|
||||
Direction::Outbound => write!(f, "Outbound"),
|
||||
Direction::Inbound => write!(f, "Inbound"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for Direction {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ConnectionExtra {
|
||||
pub(crate) end_timestamp: u64,
|
||||
pub(crate) direction: Direction,
|
||||
}
|
||||
|
||||
pub trait Connection {
|
||||
fn redirect_info(&self) -> Option<RedirectInfo> {
|
||||
let redirect_address = if self.is_ipv6() {
|
||||
IpAddress::Ipv6(Ipv6Address::LOOPBACK)
|
||||
} else {
|
||||
IpAddress::Ipv4(Ipv4Address::new(127, 0, 0, 1))
|
||||
};
|
||||
|
||||
match self.get_verdict() {
|
||||
Verdict::RedirectNameServer => Some(RedirectInfo {
|
||||
local_address: self.get_local_address(),
|
||||
remote_address: self.get_remote_address(),
|
||||
remote_port: self.get_remote_port(),
|
||||
redirect_port: PM_DNS_PORT,
|
||||
unify: false,
|
||||
redirect_address,
|
||||
}),
|
||||
Verdict::RedirectTunnel => Some(RedirectInfo {
|
||||
local_address: self.get_local_address(),
|
||||
remote_address: self.get_remote_address(),
|
||||
remote_port: self.get_remote_port(),
|
||||
redirect_port: PM_SPN_PORT,
|
||||
unify: true,
|
||||
redirect_address,
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the key of the connection.
|
||||
fn get_key(&self) -> Key {
|
||||
Key {
|
||||
protocol: self.get_protocol(),
|
||||
local_address: self.get_local_address(),
|
||||
local_port: self.get_local_port(),
|
||||
remote_address: self.get_remote_address(),
|
||||
remote_port: self.get_remote_port(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the connection is equal to the given key. The key is considered equal if the remote port and address are equal.
|
||||
fn remote_equals(&self, key: &Key) -> bool;
|
||||
/// Returns true if the connection is equal to the given key for redirecting. The key is considered equal if the remote port and address are equal.
|
||||
fn redirect_equals(&self, key: &Key) -> bool;
|
||||
/// Returns the protocol of the connection.
|
||||
fn get_protocol(&self) -> IpProtocol;
|
||||
/// Returns the verdict of the connection.
|
||||
fn get_verdict(&self) -> Verdict;
|
||||
/// Returns the local address of the connection.
|
||||
fn get_local_address(&self) -> IpAddress;
|
||||
/// Returns the local port of the connection.
|
||||
fn get_local_port(&self) -> u16;
|
||||
/// Returns the remote address of the connection.
|
||||
fn get_remote_address(&self) -> IpAddress;
|
||||
/// Returns the remote port of the connection.
|
||||
fn get_remote_port(&self) -> u16;
|
||||
/// Returns true if the connection is an IPv6 connection.
|
||||
fn is_ipv6(&self) -> bool;
|
||||
/// Returns the direction of the connection.
|
||||
fn get_direction(&self) -> Direction;
|
||||
// Returns the process id of the connection.
|
||||
fn get_process_id(&self) -> u64;
|
||||
/// Ends the connection.
|
||||
fn end(&mut self, timestamp: u64);
|
||||
/// Returns true if the connection has ended.
|
||||
fn has_ended(&self) -> bool {
|
||||
self.get_end_time() > 0
|
||||
}
|
||||
/// Returns the timestamp when the connection ended.
|
||||
fn get_end_time(&self) -> u64;
|
||||
/// Returns the timestamp when the connection was last accessed.
|
||||
fn get_last_accessed_time(&self) -> u64;
|
||||
/// Sets the timestamp when the connection was last accessed.
|
||||
fn set_last_accessed_time(&self, timestamp: u64);
|
||||
}
|
||||
|
||||
pub struct ConnectionV4 {
|
||||
pub(crate) protocol: IpProtocol,
|
||||
pub(crate) local_address: Ipv4Address,
|
||||
pub(crate) local_port: u16,
|
||||
pub(crate) remote_address: Ipv4Address,
|
||||
pub(crate) remote_port: u16,
|
||||
pub(crate) verdict: Verdict,
|
||||
pub(crate) process_id: u64,
|
||||
pub(crate) last_accessed_timestamp: AtomicU64,
|
||||
pub(crate) extra: Box<ConnectionExtra>,
|
||||
}
|
||||
|
||||
pub struct ConnectionV6 {
|
||||
pub(crate) protocol: IpProtocol,
|
||||
pub(crate) local_address: Ipv6Address,
|
||||
pub(crate) local_port: u16,
|
||||
pub(crate) remote_address: Ipv6Address,
|
||||
pub(crate) remote_port: u16,
|
||||
pub(crate) verdict: Verdict,
|
||||
pub(crate) process_id: u64,
|
||||
pub(crate) last_accessed_timestamp: AtomicU64,
|
||||
pub(crate) extra: Box<ConnectionExtra>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RedirectInfo {
|
||||
pub(crate) local_address: IpAddress,
|
||||
pub(crate) remote_address: IpAddress,
|
||||
pub(crate) remote_port: u16,
|
||||
pub(crate) redirect_port: u16,
|
||||
pub(crate) unify: bool,
|
||||
pub(crate) redirect_address: IpAddress,
|
||||
}
|
||||
|
||||
impl ConnectionV4 {
|
||||
/// Creates a new ipv4 connection from the given key.
|
||||
pub fn from_key(key: &Key, process_id: u64, direction: Direction) -> Result<Self, String> {
|
||||
let IpAddress::Ipv4(local_address) = key.local_address else {
|
||||
return Err("wrong ip address version".to_string());
|
||||
};
|
||||
|
||||
let IpAddress::Ipv4(remote_address) = key.remote_address else {
|
||||
return Err("wrong ip address version".to_string());
|
||||
};
|
||||
|
||||
let timestamp = wdk::utils::get_system_timestamp_ms();
|
||||
|
||||
Ok(Self {
|
||||
protocol: key.protocol,
|
||||
local_address,
|
||||
local_port: key.local_port,
|
||||
remote_address,
|
||||
remote_port: key.remote_port,
|
||||
verdict: Verdict::Undecided,
|
||||
process_id,
|
||||
last_accessed_timestamp: AtomicU64::new(timestamp),
|
||||
extra: Box::new(ConnectionExtra {
|
||||
direction,
|
||||
end_timestamp: 0,
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for ConnectionV4 {
|
||||
fn remote_equals(&self, key: &Key) -> bool {
|
||||
if self.remote_port != key.remote_port {
|
||||
return false;
|
||||
}
|
||||
if let IpAddress::Ipv4(remote_address) = &key.remote_address {
|
||||
return self.remote_address.eq(remote_address);
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn get_key(&self) -> Key {
|
||||
Key {
|
||||
protocol: self.protocol,
|
||||
local_address: IpAddress::Ipv4(self.local_address),
|
||||
local_port: self.local_port,
|
||||
remote_address: IpAddress::Ipv4(self.remote_address),
|
||||
remote_port: self.remote_port,
|
||||
}
|
||||
}
|
||||
|
||||
fn redirect_equals(&self, key: &Key) -> bool {
|
||||
match self.verdict {
|
||||
Verdict::RedirectNameServer => {
|
||||
if key.remote_port != PM_DNS_PORT {
|
||||
return false;
|
||||
}
|
||||
|
||||
match key.remote_address {
|
||||
IpAddress::Ipv4(a) => a.is_loopback(),
|
||||
IpAddress::Ipv6(_) => false,
|
||||
}
|
||||
}
|
||||
Verdict::RedirectTunnel => {
|
||||
if key.remote_port != PM_SPN_PORT {
|
||||
return false;
|
||||
}
|
||||
key.local_address.eq(&key.remote_address)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_protocol(&self) -> IpProtocol {
|
||||
self.protocol
|
||||
}
|
||||
|
||||
fn get_verdict(&self) -> Verdict {
|
||||
self.verdict
|
||||
}
|
||||
|
||||
fn get_local_address(&self) -> IpAddress {
|
||||
IpAddress::Ipv4(self.local_address)
|
||||
}
|
||||
|
||||
fn get_local_port(&self) -> u16 {
|
||||
self.local_port
|
||||
}
|
||||
|
||||
fn get_remote_address(&self) -> IpAddress {
|
||||
IpAddress::Ipv4(self.remote_address)
|
||||
}
|
||||
|
||||
fn get_remote_port(&self) -> u16 {
|
||||
self.remote_port
|
||||
}
|
||||
|
||||
fn is_ipv6(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn get_process_id(&self) -> u64 {
|
||||
self.process_id
|
||||
}
|
||||
|
||||
fn get_direction(&self) -> Direction {
|
||||
self.extra.direction
|
||||
}
|
||||
|
||||
fn end(&mut self, timestamp: u64) {
|
||||
self.extra.end_timestamp = timestamp;
|
||||
}
|
||||
|
||||
fn get_end_time(&self) -> u64 {
|
||||
self.extra.end_timestamp
|
||||
}
|
||||
|
||||
fn get_last_accessed_time(&self) -> u64 {
|
||||
self.last_accessed_timestamp.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn set_last_accessed_time(&self, timestamp: u64) {
|
||||
self.last_accessed_timestamp
|
||||
.store(timestamp, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for ConnectionV4 {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
protocol: self.protocol,
|
||||
local_address: self.local_address,
|
||||
local_port: self.local_port,
|
||||
remote_address: self.remote_address,
|
||||
remote_port: self.remote_port,
|
||||
verdict: self.verdict,
|
||||
process_id: self.process_id,
|
||||
last_accessed_timestamp: AtomicU64::new(
|
||||
self.last_accessed_timestamp.load(Ordering::Relaxed),
|
||||
),
|
||||
extra: self.extra.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionV6 {
|
||||
/// Creates a new ipv6 connection from the given key.
|
||||
pub fn from_key(key: &Key, process_id: u64, direction: Direction) -> Result<Self, String> {
|
||||
let IpAddress::Ipv6(local_address) = key.local_address else {
|
||||
return Err("wrong ip address version".to_string());
|
||||
};
|
||||
|
||||
let IpAddress::Ipv6(remote_address) = key.remote_address else {
|
||||
return Err("wrong ip address version".to_string());
|
||||
};
|
||||
let timestamp = wdk::utils::get_system_timestamp_ms();
|
||||
|
||||
Ok(Self {
|
||||
protocol: key.protocol,
|
||||
local_address,
|
||||
local_port: key.local_port,
|
||||
remote_address,
|
||||
remote_port: key.remote_port,
|
||||
verdict: Verdict::Undecided,
|
||||
process_id,
|
||||
last_accessed_timestamp: AtomicU64::new(timestamp),
|
||||
extra: Box::new(ConnectionExtra {
|
||||
direction,
|
||||
end_timestamp: 0,
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for ConnectionV6 {
|
||||
fn remote_equals(&self, key: &Key) -> bool {
|
||||
if self.remote_port != key.remote_port {
|
||||
return false;
|
||||
}
|
||||
if let IpAddress::Ipv6(remote_address) = &key.remote_address {
|
||||
return self.remote_address.eq(remote_address);
|
||||
}
|
||||
false
|
||||
}
|
||||
fn get_key(&self) -> Key {
|
||||
Key {
|
||||
protocol: self.protocol,
|
||||
local_address: IpAddress::Ipv6(self.local_address),
|
||||
local_port: self.local_port,
|
||||
remote_address: IpAddress::Ipv6(self.remote_address),
|
||||
remote_port: self.remote_port,
|
||||
}
|
||||
}
|
||||
|
||||
fn redirect_equals(&self, key: &Key) -> bool {
|
||||
match self.verdict {
|
||||
Verdict::RedirectNameServer => {
|
||||
if key.remote_port != PM_DNS_PORT {
|
||||
return false;
|
||||
}
|
||||
|
||||
match key.remote_address {
|
||||
IpAddress::Ipv4(_) => false,
|
||||
IpAddress::Ipv6(a) => a.is_loopback(),
|
||||
}
|
||||
}
|
||||
Verdict::RedirectTunnel => {
|
||||
if key.remote_port != PM_SPN_PORT {
|
||||
return false;
|
||||
}
|
||||
key.local_address.eq(&key.remote_address)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_protocol(&self) -> IpProtocol {
|
||||
self.protocol
|
||||
}
|
||||
|
||||
fn get_verdict(&self) -> Verdict {
|
||||
self.verdict
|
||||
}
|
||||
|
||||
fn get_local_address(&self) -> IpAddress {
|
||||
IpAddress::Ipv6(self.local_address)
|
||||
}
|
||||
|
||||
fn get_local_port(&self) -> u16 {
|
||||
self.local_port
|
||||
}
|
||||
|
||||
fn get_remote_address(&self) -> IpAddress {
|
||||
IpAddress::Ipv6(self.remote_address)
|
||||
}
|
||||
|
||||
fn get_remote_port(&self) -> u16 {
|
||||
self.remote_port
|
||||
}
|
||||
|
||||
fn is_ipv6(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn get_process_id(&self) -> u64 {
|
||||
self.process_id
|
||||
}
|
||||
|
||||
fn get_direction(&self) -> Direction {
|
||||
self.extra.direction
|
||||
}
|
||||
|
||||
fn end(&mut self, timestamp: u64) {
|
||||
self.extra.end_timestamp = timestamp;
|
||||
}
|
||||
|
||||
fn get_end_time(&self) -> u64 {
|
||||
self.extra.end_timestamp
|
||||
}
|
||||
|
||||
fn get_last_accessed_time(&self) -> u64 {
|
||||
self.last_accessed_timestamp.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn set_last_accessed_time(&self, timestamp: u64) {
|
||||
self.last_accessed_timestamp
|
||||
.store(timestamp, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for ConnectionV6 {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
protocol: self.protocol,
|
||||
local_address: self.local_address,
|
||||
local_port: self.local_port,
|
||||
remote_address: self.remote_address,
|
||||
remote_port: self.remote_port,
|
||||
verdict: self.verdict,
|
||||
process_id: self.process_id,
|
||||
last_accessed_timestamp: AtomicU64::new(
|
||||
self.last_accessed_timestamp.load(Ordering::Relaxed),
|
||||
),
|
||||
extra: self.extra.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
200
windows_kext/driver/src/connection_cache.rs
Normal file
200
windows_kext/driver/src/connection_cache.rs
Normal file
@@ -0,0 +1,200 @@
|
||||
use core::time::Duration;
|
||||
|
||||
use crate::{
|
||||
connection::{Connection, ConnectionV4, ConnectionV6, RedirectInfo, Verdict},
|
||||
connection_map::{ConnectionMap, Key},
|
||||
};
|
||||
use alloc::{format, string::String, vec::Vec};
|
||||
|
||||
use smoltcp::wire::IpProtocol;
|
||||
use wdk::rw_spin_lock::RwSpinLock;
|
||||
|
||||
pub struct ConnectionCache {
|
||||
connections_v4: ConnectionMap<ConnectionV4>,
|
||||
connections_v6: ConnectionMap<ConnectionV6>,
|
||||
lock_v4: RwSpinLock,
|
||||
lock_v6: RwSpinLock,
|
||||
}
|
||||
|
||||
impl ConnectionCache {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
connections_v4: ConnectionMap::new(),
|
||||
connections_v6: ConnectionMap::new(),
|
||||
lock_v4: RwSpinLock::default(),
|
||||
lock_v6: RwSpinLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_connection_v4(&mut self, connection: ConnectionV4) {
|
||||
let _guard = self.lock_v4.write_lock();
|
||||
self.connections_v4.add(connection);
|
||||
}
|
||||
|
||||
pub fn add_connection_v6(&mut self, connection: ConnectionV6) {
|
||||
let _guard = self.lock_v6.write_lock();
|
||||
self.connections_v6.add(connection);
|
||||
}
|
||||
|
||||
pub fn update_connection(&mut self, key: Key, verdict: Verdict) -> Option<RedirectInfo> {
|
||||
if key.is_ipv6() {
|
||||
let _guard = self.lock_v6.write_lock();
|
||||
if let Some(conn) = self.connections_v6.get_mut(&key) {
|
||||
conn.verdict = verdict;
|
||||
return conn.redirect_info();
|
||||
}
|
||||
} else {
|
||||
let _guard = self.lock_v4.write_lock();
|
||||
if let Some(conn) = self.connections_v4.get_mut(&key) {
|
||||
conn.verdict = verdict;
|
||||
return conn.redirect_info();
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn read_connection_v4<T>(
|
||||
&self,
|
||||
key: &Key,
|
||||
process_connection: fn(&ConnectionV4) -> Option<T>,
|
||||
) -> Option<T> {
|
||||
let _guard = self.lock_v4.read_lock();
|
||||
self.connections_v4.read(&key, process_connection)
|
||||
}
|
||||
|
||||
pub fn read_connection_v6<T>(
|
||||
&self,
|
||||
key: &Key,
|
||||
process_connection: fn(&ConnectionV6) -> Option<T>,
|
||||
) -> Option<T> {
|
||||
let _guard = self.lock_v6.read_lock();
|
||||
self.connections_v6.read(&key, process_connection)
|
||||
}
|
||||
|
||||
pub fn end_connection_v4(&mut self, key: Key) -> Option<ConnectionV4> {
|
||||
let _guard = self.lock_v4.write_lock();
|
||||
self.connections_v4.end(key)
|
||||
}
|
||||
|
||||
pub fn end_connection_v6(&mut self, key: Key) -> Option<ConnectionV6> {
|
||||
let _guard = self.lock_v6.write_lock();
|
||||
self.connections_v6.end(key)
|
||||
}
|
||||
|
||||
pub fn end_all_on_port_v4(&mut self, key: (IpProtocol, u16)) -> Option<Vec<ConnectionV4>> {
|
||||
let _guard = self.lock_v4.write_lock();
|
||||
self.connections_v4.end_all_on_port(key)
|
||||
}
|
||||
|
||||
pub fn end_all_on_port_v6(&mut self, key: (IpProtocol, u16)) -> Option<Vec<ConnectionV6>> {
|
||||
let _guard = self.lock_v6.write_lock();
|
||||
self.connections_v6.end_all_on_port(key)
|
||||
}
|
||||
|
||||
pub fn clean_ended_connections(&mut self) {
|
||||
{
|
||||
let _guard = self.lock_v4.write_lock();
|
||||
self.connections_v4.clean_ended_connections();
|
||||
}
|
||||
{
|
||||
let _guard = self.lock_v6.write_lock();
|
||||
self.connections_v6.clean_ended_connections();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
{
|
||||
let _guard = self.lock_v4.write_lock();
|
||||
self.connections_v4.clear();
|
||||
}
|
||||
{
|
||||
let _guard = self.lock_v6.write_lock();
|
||||
self.connections_v6.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn get_entries_count(&self) -> usize {
|
||||
let mut size = 0;
|
||||
{
|
||||
let _guard = self.lock_v4.read_lock();
|
||||
size += self.connections_v4.get_count();
|
||||
}
|
||||
|
||||
{
|
||||
let _guard = self.lock_v6.read_lock();
|
||||
size += self.connections_v6.get_count();
|
||||
}
|
||||
|
||||
return size;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn get_full_cache_info(&self) -> String {
|
||||
let mut info = String::new();
|
||||
let now = wdk::utils::get_system_timestamp_ms();
|
||||
{
|
||||
let _guard = self.lock_v4.read_lock();
|
||||
for ((protocol, port), connections) in self.connections_v4.iter() {
|
||||
info.push_str(&format!("{} -> {}\n", protocol, port,));
|
||||
for conn in connections {
|
||||
let active_time_seconds =
|
||||
Duration::from_millis(now - conn.get_last_accessed_time()).as_secs();
|
||||
info.push_str(&format!(
|
||||
"\t{}:{} -> {}:{} {} last active {}m {}s ago",
|
||||
conn.local_address,
|
||||
conn.local_port,
|
||||
conn.remote_address,
|
||||
conn.remote_port,
|
||||
conn.verdict,
|
||||
active_time_seconds / 60,
|
||||
active_time_seconds % 60
|
||||
));
|
||||
if conn.has_ended() {
|
||||
let end_time_seconds =
|
||||
Duration::from_millis(now - conn.get_end_time()).as_secs();
|
||||
info.push_str(&format!(
|
||||
"\t ended {}m {}s ago",
|
||||
end_time_seconds / 60,
|
||||
end_time_seconds % 60
|
||||
));
|
||||
}
|
||||
info.push('\n');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let _guard = self.lock_v6.read_lock();
|
||||
for ((protocol, port), connections) in self.connections_v6.iter() {
|
||||
info.push_str(&format!("{} -> {} \n", protocol, port));
|
||||
for conn in connections {
|
||||
let active_time_seconds =
|
||||
Duration::from_millis(now - conn.get_last_accessed_time()).as_secs();
|
||||
info.push_str(&format!(
|
||||
"\t{}:{} -> {}:{} {} last active {}m {}s ago",
|
||||
conn.local_address,
|
||||
conn.local_port,
|
||||
conn.remote_address,
|
||||
conn.remote_port,
|
||||
conn.verdict,
|
||||
active_time_seconds / 60,
|
||||
active_time_seconds % 60
|
||||
));
|
||||
if conn.has_ended() {
|
||||
let end_time_seconds =
|
||||
Duration::from_millis(now - conn.get_end_time()).as_secs();
|
||||
info.push_str(&format!(
|
||||
"\t ended {}m {}s ago",
|
||||
end_time_seconds / 60,
|
||||
end_time_seconds % 60
|
||||
));
|
||||
}
|
||||
info.push('\n');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return info;
|
||||
}
|
||||
}
|
||||
179
windows_kext/driver/src/connection_map.rs
Normal file
179
windows_kext/driver/src/connection_map.rs
Normal file
@@ -0,0 +1,179 @@
|
||||
use core::{fmt::Display, time::Duration};
|
||||
|
||||
use crate::connection::Connection;
|
||||
use alloc::vec::Vec;
|
||||
use hashbrown::HashMap;
|
||||
use smoltcp::wire::{IpAddress, IpProtocol};
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
|
||||
pub struct Key {
|
||||
pub(crate) protocol: IpProtocol,
|
||||
pub(crate) local_address: IpAddress,
|
||||
pub(crate) local_port: u16,
|
||||
pub(crate) remote_address: IpAddress,
|
||||
pub(crate) remote_port: u16,
|
||||
}
|
||||
|
||||
impl Display for Key {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"p: {} l: {}:{} r: {}:{}",
|
||||
self.protocol,
|
||||
self.local_address,
|
||||
self.local_port,
|
||||
self.remote_address,
|
||||
self.remote_port
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Key {
|
||||
/// Returns the protocol and port as a tuple.
|
||||
pub fn small(&self) -> (IpProtocol, u16) {
|
||||
(self.protocol, self.local_port)
|
||||
}
|
||||
|
||||
/// Returns true if the local address is an IPv4 address.
|
||||
pub fn is_ipv6(&self) -> bool {
|
||||
match self.local_address {
|
||||
IpAddress::Ipv4(_) => false,
|
||||
IpAddress::Ipv6(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the local address is a loopback address.
|
||||
pub fn is_loopback(&self) -> bool {
|
||||
match self.local_address {
|
||||
IpAddress::Ipv4(ip) => ip.is_loopback(),
|
||||
IpAddress::Ipv6(ip) => ip.is_loopback(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new key with the local and remote addresses and ports reversed.
|
||||
#[allow(dead_code)]
|
||||
pub fn reverse(&self) -> Key {
|
||||
Key {
|
||||
protocol: self.protocol,
|
||||
local_address: self.remote_address,
|
||||
local_port: self.remote_port,
|
||||
remote_address: self.local_address,
|
||||
remote_port: self.local_port,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConnectionMap<T: Connection>(HashMap<(IpProtocol, u16), Vec<T>>);
|
||||
|
||||
impl<T: Connection + Clone> ConnectionMap<T> {
|
||||
pub fn new() -> Self {
|
||||
Self(HashMap::new())
|
||||
}
|
||||
|
||||
pub fn add(&mut self, conn: T) {
|
||||
let key = conn.get_key().small();
|
||||
if let Some(connections) = self.0.get_mut(&key) {
|
||||
connections.push(conn);
|
||||
} else {
|
||||
self.0.insert(key, alloc::vec![conn]);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self, key: &Key) -> Option<&mut T> {
|
||||
if let Some(connections) = self.0.get_mut(&key.small()) {
|
||||
for conn in connections {
|
||||
if conn.remote_equals(key) {
|
||||
conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
|
||||
return Some(conn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub fn read<C>(&self, key: &Key, read_connection: fn(&T) -> Option<C>) -> Option<C> {
|
||||
if let Some(connections) = self.0.get(&key.small()) {
|
||||
for conn in connections {
|
||||
if conn.remote_equals(key) {
|
||||
conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
|
||||
return read_connection(conn);
|
||||
}
|
||||
if conn.redirect_equals(key) {
|
||||
conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
|
||||
return read_connection(conn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub fn end(&mut self, key: Key) -> Option<T> {
|
||||
if let Some(connections) = self.0.get_mut(&key.small()) {
|
||||
for (_, conn) in connections.iter_mut().enumerate() {
|
||||
if conn.remote_equals(&key) {
|
||||
conn.end(wdk::utils::get_system_timestamp_ms());
|
||||
return Some(conn.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
pub fn end_all_on_port(&mut self, key: (IpProtocol, u16)) -> Option<Vec<T>> {
|
||||
if let Some(connections) = self.0.get_mut(&key) {
|
||||
let mut vec = Vec::with_capacity(connections.len());
|
||||
for (_, conn) in connections.iter_mut().enumerate() {
|
||||
if !conn.has_ended() {
|
||||
conn.end(wdk::utils::get_system_timestamp_ms());
|
||||
vec.push(conn.clone());
|
||||
}
|
||||
}
|
||||
return Some(vec);
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.0.clear();
|
||||
}
|
||||
|
||||
pub fn clean_ended_connections(&mut self) {
|
||||
let now = wdk::utils::get_system_timestamp_ms();
|
||||
const TEN_MINUETS: u64 = Duration::from_secs(60 * 10).as_millis() as u64;
|
||||
let before_ten_minutes = now - TEN_MINUETS;
|
||||
let before_one_minute = now - Duration::from_secs(60).as_millis() as u64;
|
||||
|
||||
for (_, connections) in self.0.iter_mut() {
|
||||
connections.retain(|c| {
|
||||
if c.has_ended() && c.get_end_time() < before_one_minute {
|
||||
// Ended more than 1 minute ago
|
||||
return false;
|
||||
}
|
||||
|
||||
if c.get_last_accessed_time() < before_ten_minutes {
|
||||
// Last active more than 10 minutes ago
|
||||
return false;
|
||||
}
|
||||
|
||||
// Keep
|
||||
return true;
|
||||
});
|
||||
}
|
||||
self.0.retain(|_, v| !v.is_empty());
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn get_count(&self) -> usize {
|
||||
let mut count = 0;
|
||||
for conn in self.0.values() {
|
||||
count += conn.len();
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> hashbrown::hash_map::Iter<'_, (IpProtocol, u16), Vec<T>> {
|
||||
self.0.iter()
|
||||
}
|
||||
}
|
||||
329
windows_kext/driver/src/device.rs
Normal file
329
windows_kext/driver/src/device.rs
Normal file
@@ -0,0 +1,329 @@
|
||||
use alloc::string::String;
|
||||
use num_traits::FromPrimitive;
|
||||
use protocol::{command::CommandType, info::Info};
|
||||
use smoltcp::wire::{IpAddress, IpProtocol, Ipv4Address, Ipv6Address};
|
||||
use wdk::{
|
||||
driver::Driver,
|
||||
filter_engine::{
|
||||
callout_data::ClassifyDefer,
|
||||
net_buffer::{NetBufferList, NetworkAllocator},
|
||||
packet::{InjectInfo, Injector},
|
||||
FilterEngine,
|
||||
},
|
||||
ioqueue::{self, IOQueue},
|
||||
irp_helpers::{ReadRequest, WriteRequest},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
array_holder::ArrayHolder, bandwidth::Bandwidth, callouts, connection_cache::ConnectionCache,
|
||||
connection_map::Key, dbg, err, id_cache::IdCache, logger, packet_util::Redirect,
|
||||
};
|
||||
|
||||
pub enum Packet {
|
||||
PacketLayer(NetBufferList, InjectInfo),
|
||||
AleLayer(ClassifyDefer),
|
||||
}
|
||||
|
||||
// Device Context
|
||||
pub struct Device {
|
||||
pub(crate) filter_engine: FilterEngine,
|
||||
pub(crate) read_leftover: ArrayHolder,
|
||||
pub(crate) event_queue: IOQueue<Info>,
|
||||
pub(crate) packet_cache: IdCache,
|
||||
pub(crate) connection_cache: ConnectionCache,
|
||||
pub(crate) injector: Injector,
|
||||
pub(crate) network_allocator: NetworkAllocator,
|
||||
pub(crate) bandwidth_stats: Bandwidth,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
/// Initialize all members of the device. Memory is handled by windows.
|
||||
/// Make sure everything is initialized here.
|
||||
pub fn new(driver: &Driver) -> Result<Self, String> {
|
||||
let mut filter_engine =
|
||||
match FilterEngine::new(driver, 0x7dab1057_8e2b_40c4_9b85_693e381d7896) {
|
||||
Ok(fe) => fe,
|
||||
Err(err) => return Err(alloc::format!("filter engine error: {}", err)),
|
||||
};
|
||||
|
||||
if let Err(err) = filter_engine.commit(callouts::get_callout_vec()) {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
filter_engine,
|
||||
read_leftover: ArrayHolder::default(),
|
||||
event_queue: IOQueue::new(),
|
||||
packet_cache: IdCache::new(),
|
||||
connection_cache: ConnectionCache::new(),
|
||||
injector: Injector::new(),
|
||||
network_allocator: NetworkAllocator::new(),
|
||||
bandwidth_stats: Bandwidth::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Cleanup is called just before drop.
|
||||
// pub fn cleanup(&mut self) {}
|
||||
|
||||
fn write_buffer(&mut self, read_request: &mut ReadRequest, info: Info) {
|
||||
let bytes = info.as_bytes();
|
||||
let count = read_request.write(bytes);
|
||||
|
||||
// Check if the full buffer was written.
|
||||
if count < bytes.len() {
|
||||
// Save the leftovers for later.
|
||||
self.read_leftover.save(&bytes[count..]);
|
||||
}
|
||||
}
|
||||
|
||||
/// Called when handle. Read is called from user-space.
|
||||
pub fn read(&mut self, read_request: &mut ReadRequest) {
|
||||
if let Some(data) = self.read_leftover.load() {
|
||||
// There are leftovers from previous request.
|
||||
let count = read_request.write(&data);
|
||||
|
||||
// Check if full command was written.
|
||||
if count < data.len() {
|
||||
// Save the leftovers for later.
|
||||
self.read_leftover.save(&data[count..]);
|
||||
}
|
||||
} else {
|
||||
// Noting left from before. Wait for next commands.
|
||||
match self.event_queue.wait_and_pop() {
|
||||
Ok(info) => {
|
||||
self.write_buffer(read_request, info);
|
||||
}
|
||||
Err(ioqueue::Status::Timeout) => {
|
||||
// Timeout. This will only trigger if pop function is called with timeout.
|
||||
read_request.timeout();
|
||||
return;
|
||||
}
|
||||
Err(err) => {
|
||||
// Queue failed. Send EOF, to notify user-space. Usually happens on rundown.
|
||||
err!("failed to pop value: {}", err);
|
||||
read_request.end_of_file();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we have more space. InfoType + data_size == 5 bytes
|
||||
while read_request.free_space() > 5 {
|
||||
match self.event_queue.pop() {
|
||||
Ok(info) => {
|
||||
self.write_buffer(read_request, info);
|
||||
}
|
||||
Err(_) => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
read_request.complete();
|
||||
}
|
||||
|
||||
// Called when handle.Write is called from user-space.
|
||||
pub fn write(&mut self, write_request: &mut WriteRequest) {
|
||||
// Try parsing the command.
|
||||
let mut buffer = write_request.get_buffer();
|
||||
let command = protocol::command::parse_type(buffer);
|
||||
let Some(command) = command else {
|
||||
err!("Unknown command number: {}", buffer[0]);
|
||||
return;
|
||||
};
|
||||
buffer = &buffer[1..];
|
||||
|
||||
let mut _classify_defer = None;
|
||||
|
||||
match command {
|
||||
CommandType::Shutdown => {
|
||||
wdk::dbg!("Shutdown command");
|
||||
self.shutdown();
|
||||
}
|
||||
CommandType::Verdict => {
|
||||
let verdict = protocol::command::parse_verdict(buffer);
|
||||
wdk::dbg!("Verdict command");
|
||||
// Received verdict decision for a specific connection.
|
||||
if let Some((key, mut packet)) = self.packet_cache.pop_id(verdict.id) {
|
||||
if let Some(verdict) = FromPrimitive::from_u8(verdict.verdict) {
|
||||
dbg!("Verdict received {}: {}", key, verdict);
|
||||
// Add verdict in the cache.
|
||||
let redirect_info = self.connection_cache.update_connection(key, verdict);
|
||||
|
||||
// if verdict.is_permanent() {
|
||||
// dbg!(self.logger, "resetting filters {}: {}", key, verdict);
|
||||
// _ = self.filter_engine.reset_all_filters();
|
||||
// }
|
||||
|
||||
match verdict {
|
||||
crate::connection::Verdict::Accept
|
||||
| crate::connection::Verdict::PermanentAccept => {
|
||||
if let Err(err) = self.inject_packet(packet, false) {
|
||||
err!("failed to inject packet: {}", err);
|
||||
} else {
|
||||
dbg!("packet injected: {}", key);
|
||||
}
|
||||
}
|
||||
crate::connection::Verdict::RedirectNameServer
|
||||
| crate::connection::Verdict::RedirectTunnel => {
|
||||
if let Some(redirect_info) = redirect_info {
|
||||
if let Err(err) = packet.redirect(redirect_info) {
|
||||
err!("failed to redirect packet: {}", err);
|
||||
}
|
||||
if let Err(err) = self.inject_packet(packet, false) {
|
||||
err!("failed to inject packet: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if let Err(err) = self.inject_packet(packet, true) {
|
||||
err!("failed to inject packet: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} else {
|
||||
// Id was not in the packet cache.
|
||||
let id = verdict.id;
|
||||
err!("Verdict invalid id: {}", id);
|
||||
}
|
||||
}
|
||||
CommandType::UpdateV4 => {
|
||||
let update = protocol::command::parse_update_v4(buffer);
|
||||
// Build the new action.
|
||||
if let Some(verdict) = FromPrimitive::from_u8(update.verdict) {
|
||||
// Update with new action.
|
||||
dbg!("Verdict update received {:?}: {}", update, verdict);
|
||||
_classify_defer = self.connection_cache.update_connection(
|
||||
Key {
|
||||
protocol: IpProtocol::from(update.protocol),
|
||||
local_address: IpAddress::Ipv4(Ipv4Address::from_bytes(
|
||||
&update.local_address,
|
||||
)),
|
||||
local_port: update.local_port,
|
||||
remote_address: IpAddress::Ipv4(Ipv4Address::from_bytes(
|
||||
&update.remote_address,
|
||||
)),
|
||||
remote_port: update.remote_port,
|
||||
},
|
||||
verdict,
|
||||
);
|
||||
} else {
|
||||
err!("invalid verdict value: {}", update.verdict);
|
||||
}
|
||||
}
|
||||
CommandType::UpdateV6 => {
|
||||
let update = protocol::command::parse_update_v6(buffer);
|
||||
// Build the new action.
|
||||
if let Some(verdict) = FromPrimitive::from_u8(update.verdict) {
|
||||
// Update with new action.
|
||||
dbg!("Verdict update received {:?}: {}", update, verdict);
|
||||
_classify_defer = self.connection_cache.update_connection(
|
||||
Key {
|
||||
protocol: IpProtocol::from(update.protocol),
|
||||
local_address: IpAddress::Ipv6(Ipv6Address::from_bytes(
|
||||
&update.local_address,
|
||||
)),
|
||||
local_port: update.local_port,
|
||||
remote_address: IpAddress::Ipv6(Ipv6Address::from_bytes(
|
||||
&update.remote_address,
|
||||
)),
|
||||
remote_port: update.remote_port,
|
||||
},
|
||||
verdict,
|
||||
);
|
||||
} else {
|
||||
err!("invalid verdict value: {}", update.verdict);
|
||||
}
|
||||
}
|
||||
CommandType::ClearCache => {
|
||||
wdk::dbg!("ClearCache command");
|
||||
self.connection_cache.clear();
|
||||
if let Err(err) = self.filter_engine.reset_all_filters() {
|
||||
err!("failed to reset filters: {}", err);
|
||||
}
|
||||
}
|
||||
CommandType::GetLogs => {
|
||||
wdk::dbg!("GetLogs command");
|
||||
let lines_vec = logger::flush();
|
||||
for line in lines_vec {
|
||||
let _ = self.event_queue.push(line);
|
||||
}
|
||||
}
|
||||
CommandType::GetBandwidthStats => {
|
||||
wdk::dbg!("GetBandwidthStats command");
|
||||
let stats = self.bandwidth_stats.get_all_updates_tcp_v4();
|
||||
if let Some(stats) = stats {
|
||||
_ = self.event_queue.push(stats);
|
||||
}
|
||||
|
||||
let stats = self.bandwidth_stats.get_all_updates_tcp_v6();
|
||||
if let Some(stats) = stats {
|
||||
_ = self.event_queue.push(stats);
|
||||
}
|
||||
|
||||
let stats = self.bandwidth_stats.get_all_updates_udp_v4();
|
||||
if let Some(stats) = stats {
|
||||
_ = self.event_queue.push(stats);
|
||||
}
|
||||
|
||||
let stats = self.bandwidth_stats.get_all_updates_udp_v6();
|
||||
if let Some(stats) = stats {
|
||||
_ = self.event_queue.push(stats);
|
||||
}
|
||||
}
|
||||
CommandType::PrintMemoryStats => {
|
||||
// Getting the information takes a long time and interferes with the callouts causing the device to crash.
|
||||
// TODO(vladimir): Make more optimized version
|
||||
// info!(
|
||||
// "Packet cache: {} entries",
|
||||
// self.packet_cache.get_entries_count()
|
||||
// );
|
||||
// info!(
|
||||
// "BandwidthStats cache: {} entries",
|
||||
// self.bandwidth_stats.get_entries_count()
|
||||
// );
|
||||
// info!(
|
||||
// "Connection cache: {} entries\n {}",
|
||||
// self.connection_cache.get_entries_count(),
|
||||
// self.connection_cache.get_full_cache_info()
|
||||
// );
|
||||
}
|
||||
CommandType::CleanEndedConnections => {
|
||||
wdk::dbg!("CleanEndedConnections command");
|
||||
self.connection_cache.clean_ended_connections();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shutdown(&self) {
|
||||
// End blocking operations from the queue. This will end pending read requests.
|
||||
self.event_queue.rundown();
|
||||
}
|
||||
|
||||
pub fn inject_packet(&mut self, packet: Packet, blocked: bool) -> Result<(), String> {
|
||||
match packet {
|
||||
Packet::PacketLayer(nbl, inject_info) => {
|
||||
if !blocked {
|
||||
self.injector.inject_net_buffer_list(nbl, inject_info)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Packet::AleLayer(defer) => {
|
||||
let packet_list = defer.complete(&mut self.filter_engine)?;
|
||||
if let Some(packet_list) = packet_list {
|
||||
self.injector.inject_packet_list_transport(packet_list)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Device {
|
||||
fn drop(&mut self) {
|
||||
_ = logger::flush();
|
||||
// dbg!("Device Context drop called.");
|
||||
}
|
||||
}
|
||||
25
windows_kext/driver/src/driver_hashmap.rs
Normal file
25
windows_kext/driver/src/driver_hashmap.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use core::ops::{Deref, DerefMut};
|
||||
|
||||
use hashbrown::HashMap;
|
||||
|
||||
pub struct DeviceHashMap<Key, Value>(Option<HashMap<Key, Value>>);
|
||||
|
||||
impl<Key, Value> DeviceHashMap<Key, Value> {
|
||||
pub fn new() -> Self {
|
||||
Self(Some(HashMap::new()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Value> Deref for DeviceHashMap<Key, Value> {
|
||||
type Target = HashMap<Key, Value>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.as_ref().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Value> DerefMut for DeviceHashMap<Key, Value> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.0.as_mut().unwrap()
|
||||
}
|
||||
}
|
||||
135
windows_kext/driver/src/entry.rs
Normal file
135
windows_kext/driver/src/entry.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use crate::common::ControlCode;
|
||||
use crate::device;
|
||||
use alloc::boxed::Box;
|
||||
use num_traits::FromPrimitive;
|
||||
use wdk::irp_helpers::{DeviceControlRequest, ReadRequest, WriteRequest};
|
||||
use wdk::{err, info, interface};
|
||||
use windows_sys::Wdk::Foundation::{DEVICE_OBJECT, DRIVER_OBJECT, IRP};
|
||||
use windows_sys::Win32::Foundation::{NTSTATUS, STATUS_SUCCESS};
|
||||
|
||||
static VERSION: [u8; 4] = include!("../../kext_interface/version.txt");
|
||||
|
||||
static mut DEVICE: *mut device::Device = core::ptr::null_mut();
|
||||
pub fn get_device() -> Option<&'static mut device::Device> {
|
||||
return unsafe { DEVICE.as_mut() };
|
||||
}
|
||||
|
||||
// DriverEntry is the entry point of the driver (main function). Will be called when driver is loaded.
|
||||
// Name should not be changed
|
||||
#[export_name = "DriverEntry"]
|
||||
pub extern "system" fn driver_entry(
|
||||
driver_object: *mut windows_sys::Wdk::Foundation::DRIVER_OBJECT,
|
||||
registry_path: *mut windows_sys::Win32::Foundation::UNICODE_STRING,
|
||||
) -> windows_sys::Win32::Foundation::NTSTATUS {
|
||||
info!("Starting initialization...");
|
||||
|
||||
// Initialize driver object.
|
||||
let mut driver = match interface::init_driver_object(
|
||||
driver_object,
|
||||
registry_path,
|
||||
"PortmasterKext",
|
||||
core::ptr::null_mut(),
|
||||
) {
|
||||
Ok(driver) => driver,
|
||||
Err(status) => {
|
||||
err!("driver_entry: failed to initialize driver: {}", status);
|
||||
return windows_sys::Win32::Foundation::STATUS_FAILED_DRIVER_ENTRY;
|
||||
}
|
||||
};
|
||||
|
||||
// Set driver functions.
|
||||
driver.set_driver_unload(driver_unload);
|
||||
driver.set_read_fn(driver_read);
|
||||
driver.set_write_fn(driver_write);
|
||||
driver.set_device_control_fn(device_control);
|
||||
|
||||
// Initialize device.
|
||||
unsafe {
|
||||
let device = match device::Device::new(&driver) {
|
||||
Ok(device) => Box::new(device),
|
||||
Err(err) => {
|
||||
wdk::err!("filed to initialize device: {}", err);
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
DEVICE = Box::into_raw(device);
|
||||
}
|
||||
|
||||
STATUS_SUCCESS
|
||||
}
|
||||
|
||||
// driver_unload function is called when service delete is called from user-space.
|
||||
unsafe extern "system" fn driver_unload(_object: *const DRIVER_OBJECT) {
|
||||
info!("Unloading complete");
|
||||
unsafe {
|
||||
if !DEVICE.is_null() {
|
||||
_ = Box::from_raw(DEVICE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// driver_read event triggered from user-space on file.Read.
|
||||
unsafe extern "system" fn driver_read(
|
||||
_device_object: &mut DEVICE_OBJECT,
|
||||
irp: &mut IRP,
|
||||
) -> NTSTATUS {
|
||||
let mut read_request = ReadRequest::new(irp);
|
||||
let Some(device) = get_device() else {
|
||||
read_request.complete();
|
||||
|
||||
return read_request.get_status();
|
||||
};
|
||||
|
||||
device.read(&mut read_request);
|
||||
read_request.get_status()
|
||||
}
|
||||
|
||||
/// driver_write event triggered from user-space on file.Write.
|
||||
unsafe extern "system" fn driver_write(
|
||||
_device_object: &mut DEVICE_OBJECT,
|
||||
irp: &mut IRP,
|
||||
) -> NTSTATUS {
|
||||
let mut write_request = WriteRequest::new(irp);
|
||||
let Some(device) = get_device() else {
|
||||
write_request.complete();
|
||||
return write_request.get_status();
|
||||
};
|
||||
|
||||
device.write(&mut write_request);
|
||||
|
||||
write_request.mark_all_as_read();
|
||||
write_request.complete();
|
||||
write_request.get_status()
|
||||
}
|
||||
|
||||
/// device_control event triggered from user-space on file.deviceIOControl.
|
||||
unsafe extern "system" fn device_control(
|
||||
_device_object: &mut DEVICE_OBJECT,
|
||||
irp: &mut IRP,
|
||||
) -> NTSTATUS {
|
||||
let mut control_request = DeviceControlRequest::new(irp);
|
||||
let Some(device) = get_device() else {
|
||||
control_request.complete();
|
||||
return control_request.get_status();
|
||||
};
|
||||
|
||||
let Some(control_code): Option<ControlCode> =
|
||||
FromPrimitive::from_u32(control_request.get_control_code())
|
||||
else {
|
||||
wdk::info!("Unknown IOCTL code: {}", control_request.get_control_code());
|
||||
control_request.not_implemented();
|
||||
return control_request.get_status();
|
||||
};
|
||||
|
||||
wdk::info!("IOCTL: {}", control_code);
|
||||
|
||||
match control_code {
|
||||
ControlCode::Version => {
|
||||
control_request.write(&VERSION);
|
||||
}
|
||||
ControlCode::ShutdownRequest => device.shutdown(),
|
||||
};
|
||||
|
||||
control_request.complete();
|
||||
control_request.get_status()
|
||||
}
|
||||
131
windows_kext/driver/src/id_cache.rs
Normal file
131
windows_kext/driver/src/id_cache.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
use alloc::collections::VecDeque;
|
||||
use protocol::info::Info;
|
||||
use smoltcp::wire::{IpAddress, IpProtocol};
|
||||
use wdk::rw_spin_lock::RwSpinLock;
|
||||
|
||||
use crate::{connection::Direction, connection_map::Key, device::Packet};
|
||||
|
||||
struct Entry<T> {
|
||||
value: T,
|
||||
id: u64,
|
||||
}
|
||||
|
||||
pub struct IdCache {
|
||||
values: VecDeque<Entry<(Key, Packet)>>,
|
||||
lock: RwSpinLock,
|
||||
next_id: u64,
|
||||
}
|
||||
|
||||
impl IdCache {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
values: VecDeque::with_capacity(1000),
|
||||
lock: RwSpinLock::default(),
|
||||
next_id: 1, // 0 is invalid id
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(
|
||||
&mut self,
|
||||
value: (Key, Packet),
|
||||
process_id: u64,
|
||||
direction: Direction,
|
||||
ale_layer: bool,
|
||||
) -> Option<Info> {
|
||||
let _guard = self.lock.write_lock();
|
||||
let id = self.next_id;
|
||||
let info = build_info(&value.0, id, process_id, direction, &value.1, ale_layer);
|
||||
self.values.push_back(Entry { value, id });
|
||||
self.next_id = self.next_id.wrapping_add(1); // Assuming this will not overflow.
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
pub fn pop_id(&mut self, id: u64) -> Option<(Key, Packet)> {
|
||||
let _guard = self.lock.write_lock();
|
||||
if let Ok(index) = self.values.binary_search_by_key(&id, |val| val.id) {
|
||||
return Some(self.values.remove(index).unwrap().value);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn get_entries_count(&self) -> usize {
|
||||
let _guard = self.lock.read_lock();
|
||||
return self.values.len();
|
||||
}
|
||||
}
|
||||
|
||||
fn get_payload<'a>(packet: &'a Packet) -> Option<&'a [u8]> {
|
||||
match packet {
|
||||
Packet::PacketLayer(nbl, _) => nbl.get_data(),
|
||||
Packet::AleLayer(defer) => {
|
||||
let p = match defer {
|
||||
wdk::filter_engine::callout_data::ClassifyDefer::Initial(_, p) => p,
|
||||
wdk::filter_engine::callout_data::ClassifyDefer::Reauthorization(_, p) => p,
|
||||
};
|
||||
if let Some(tpl) = p {
|
||||
tpl.net_buffer_list_queue.get_data()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_info(
|
||||
key: &Key,
|
||||
packet_id: u64,
|
||||
process_id: u64,
|
||||
direction: Direction,
|
||||
packet: &Packet,
|
||||
ale_layer: bool,
|
||||
) -> Option<Info> {
|
||||
let (local_port, remote_port) = match key.protocol {
|
||||
IpProtocol::Tcp | IpProtocol::Udp => (key.local_port, key.remote_port),
|
||||
_ => (0, 0),
|
||||
};
|
||||
|
||||
let payload_layer = if ale_layer {
|
||||
4 // Transport layer
|
||||
} else {
|
||||
3 // Network layer
|
||||
};
|
||||
|
||||
let mut payload = &[][..];
|
||||
if let Some(p) = get_payload(packet) {
|
||||
payload = p;
|
||||
}
|
||||
|
||||
match (key.local_address, key.remote_address) {
|
||||
(IpAddress::Ipv6(local_ip), IpAddress::Ipv6(remote_ip)) if key.is_ipv6() => {
|
||||
Some(protocol::info::connection_info_v6(
|
||||
packet_id,
|
||||
process_id,
|
||||
direction as u8,
|
||||
u8::from(key.protocol),
|
||||
local_ip.0,
|
||||
remote_ip.0,
|
||||
local_port,
|
||||
remote_port,
|
||||
payload_layer,
|
||||
payload,
|
||||
))
|
||||
}
|
||||
(IpAddress::Ipv4(local_ip), IpAddress::Ipv4(remote_ip)) => {
|
||||
Some(protocol::info::connection_info_v4(
|
||||
packet_id,
|
||||
process_id,
|
||||
direction as u8,
|
||||
u8::from(key.protocol),
|
||||
local_ip.0,
|
||||
remote_ip.0,
|
||||
local_port,
|
||||
remote_port,
|
||||
payload_layer,
|
||||
payload,
|
||||
))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
43
windows_kext/driver/src/lib.rs
Normal file
43
windows_kext/driver/src/lib.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
#![cfg_attr(not(test), no_std)]
|
||||
#![no_main]
|
||||
#![allow(clippy::needless_return)]
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
mod ale_callouts;
|
||||
mod array_holder;
|
||||
mod bandwidth;
|
||||
mod callouts;
|
||||
mod common;
|
||||
mod connection;
|
||||
mod connection_cache;
|
||||
mod connection_map;
|
||||
mod device;
|
||||
mod driver_hashmap;
|
||||
mod entry;
|
||||
mod id_cache;
|
||||
pub mod logger;
|
||||
mod packet_callouts;
|
||||
mod packet_util;
|
||||
mod stream_callouts;
|
||||
|
||||
use wdk::allocator::WindowsAllocator;
|
||||
|
||||
#[cfg(not(test))]
|
||||
use core::panic::PanicInfo;
|
||||
|
||||
// Declaration of the global memory allocator
|
||||
#[global_allocator]
|
||||
static HEAP: WindowsAllocator = WindowsAllocator {};
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn _DllMainCRTStartup() {}
|
||||
|
||||
#[cfg(not(test))]
|
||||
#[panic_handler]
|
||||
fn panic(info: &PanicInfo) -> ! {
|
||||
use wdk::err;
|
||||
|
||||
err!("{}", info);
|
||||
loop {}
|
||||
}
|
||||
114
windows_kext/driver/src/logger.rs
Normal file
114
windows_kext/driver/src/logger.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use alloc::boxed::Box;
|
||||
use alloc::vec::Vec;
|
||||
use core::{
|
||||
mem::MaybeUninit,
|
||||
sync::atomic::{AtomicPtr, AtomicUsize, Ordering},
|
||||
};
|
||||
use protocol::info::{Info, Severity};
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
pub const LOG_LEVEL: u8 = Severity::Warning as u8;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
pub const LOG_LEVEL: u8 = Severity::Trace as u8;
|
||||
|
||||
pub const MAX_LOG_LINE_SIZE: usize = 150;
|
||||
|
||||
static mut LOG_LINES: [AtomicPtr<Info>; 1024] = unsafe { MaybeUninit::zeroed().assume_init() };
|
||||
static START_INDEX: AtomicUsize = unsafe { MaybeUninit::zeroed().assume_init() };
|
||||
static END_INDEX: AtomicUsize = unsafe { MaybeUninit::zeroed().assume_init() };
|
||||
|
||||
pub fn add_line(log_line: Info) {
|
||||
let mut index = END_INDEX.fetch_add(1, Ordering::Acquire);
|
||||
unsafe {
|
||||
index %= LOG_LINES.len();
|
||||
let ptr = &mut LOG_LINES[index];
|
||||
let line = Box::new(log_line);
|
||||
let old = ptr.swap(Box::into_raw(line), Ordering::SeqCst);
|
||||
if !old.is_null() {
|
||||
_ = Box::from_raw(old);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flush() -> Vec<Info> {
|
||||
let mut vec = Vec::new();
|
||||
let end_index = END_INDEX.load(Ordering::Acquire);
|
||||
let start_index = START_INDEX.load(Ordering::Acquire);
|
||||
if end_index <= start_index {
|
||||
return vec;
|
||||
}
|
||||
unsafe {
|
||||
let count = end_index - start_index;
|
||||
for i in start_index..start_index + count {
|
||||
let index = i % LOG_LINES.len();
|
||||
let ptr = LOG_LINES[index].swap(core::ptr::null_mut(), Ordering::SeqCst);
|
||||
if !ptr.is_null() {
|
||||
vec.push(*Box::from_raw(ptr));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
START_INDEX.store(end_index, Ordering::Release);
|
||||
vec
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! log_internal {
|
||||
($log_line:expr, $($arg:tt)*) => ({
|
||||
use core::fmt::Write;
|
||||
_ = write!($log_line, "{}:{} ", file!(), line!());
|
||||
_ = write!($log_line, $($arg)*);
|
||||
$crate::logger::add_line($log_line);
|
||||
});
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! crit {
|
||||
($($arg:tt)*) => ({
|
||||
if protocol::info::Severity::Critical as u8 >= $crate::logger::LOG_LEVEL {
|
||||
let message = alloc::format!($($arg)*);
|
||||
$crate::logger::add_line(protocol::info::Severity::Critical, alloc::format!("{}:{} ", file!(), line!()), message)
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! err {
|
||||
($($arg:tt)*) => ({
|
||||
if protocol::info::Severity::Error as u8 >= $crate::logger::LOG_LEVEL {
|
||||
let mut log_line = protocol::info::log_line(protocol::info::Severity::Error, $crate::logger::MAX_LOG_LINE_SIZE);
|
||||
$crate::log_internal!(log_line, $($arg)*);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! warn {
|
||||
($($arg:tt)*) => ({
|
||||
if protocol::info::Severity::Warning as u8 >= $crate::logger::LOG_LEVEL {
|
||||
let mut log_line = protocol::info::log_line(protocol::info::Severity::Warning, $crate::logger::MAX_LOG_LINE_SIZE);
|
||||
$crate::log_internal!(log_line, $($arg)*);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! dbg {
|
||||
($($arg:tt)*) => ({
|
||||
if protocol::info::Severity::Debug as u8 >= $crate::logger::LOG_LEVEL {
|
||||
let mut log_line = protocol::info::log_line(protocol::info::Severity::Debug, $crate::logger::MAX_LOG_LINE_SIZE);
|
||||
$crate::log_internal!(log_line, $($arg)*);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! info {
|
||||
($($arg:tt)*) => ({
|
||||
if protocol::info::Severity::Info as u8 >= $crate::logger::LOG_LEVEL {
|
||||
let mut log_line = protocol::info::log_line(protocol::info::Severity::Info, $crate::logger::MAX_LOG_LINE_SIZE);
|
||||
$crate::log_internal!(log_line, $($arg)*);
|
||||
}
|
||||
});
|
||||
}
|
||||
298
windows_kext/driver/src/packet_callouts.rs
Normal file
298
windows_kext/driver/src/packet_callouts.rs
Normal file
@@ -0,0 +1,298 @@
|
||||
use alloc::string::String;
|
||||
use smoltcp::wire::{IPV4_HEADER_LEN, IPV6_HEADER_LEN};
|
||||
use wdk::filter_engine::callout_data::CalloutData;
|
||||
use wdk::filter_engine::layer;
|
||||
use wdk::filter_engine::net_buffer::{NetBufferList, NetBufferListIter};
|
||||
use wdk::filter_engine::packet::InjectInfo;
|
||||
|
||||
use crate::connection::{
|
||||
Connection, ConnectionV4, ConnectionV6, Direction, RedirectInfo, Verdict, PM_DNS_PORT,
|
||||
PM_SPN_PORT,
|
||||
};
|
||||
use crate::connection_cache::ConnectionCache;
|
||||
use crate::connection_map::Key;
|
||||
use crate::device::{Device, Packet};
|
||||
use crate::packet_util::{get_key_from_nbl_v4, get_key_from_nbl_v6, Redirect};
|
||||
use crate::{err, warn};
|
||||
|
||||
// IP packet layers
|
||||
pub fn ip_packet_layer_outbound_v4(data: CalloutData) {
|
||||
type Fields = layer::FieldsOutboundIppacketV4;
|
||||
let interface_index = data.get_value_u32(Fields::InterfaceIndex as usize);
|
||||
let sub_interface_index = data.get_value_u32(Fields::SubInterfaceIndex as usize);
|
||||
|
||||
ip_packet_layer(
|
||||
data,
|
||||
false,
|
||||
Direction::Outbound,
|
||||
interface_index,
|
||||
sub_interface_index,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn ip_packet_layer_inbound_v4(data: CalloutData) {
|
||||
type Fields = layer::FieldsInboundIppacketV4;
|
||||
let interface_index = data.get_value_u32(Fields::InterfaceIndex as usize);
|
||||
let sub_interface_index = data.get_value_u32(Fields::SubInterfaceIndex as usize);
|
||||
ip_packet_layer(
|
||||
data,
|
||||
false,
|
||||
Direction::Inbound,
|
||||
interface_index,
|
||||
sub_interface_index,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn ip_packet_layer_outbound_v6(data: CalloutData) {
|
||||
type Fields = layer::FieldsOutboundIppacketV6;
|
||||
let interface_index = data.get_value_u32(Fields::InterfaceIndex as usize);
|
||||
let sub_interface_index = data.get_value_u32(Fields::SubInterfaceIndex as usize);
|
||||
|
||||
ip_packet_layer(
|
||||
data,
|
||||
true,
|
||||
Direction::Outbound,
|
||||
interface_index,
|
||||
sub_interface_index,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn ip_packet_layer_inbound_v6(data: CalloutData) {
|
||||
type Fields = layer::FieldsInboundIppacketV6;
|
||||
let interface_index = data.get_value_u32(Fields::InterfaceIndex as usize);
|
||||
let sub_interface_index = data.get_value_u32(Fields::SubInterfaceIndex as usize);
|
||||
|
||||
ip_packet_layer(
|
||||
data,
|
||||
true,
|
||||
Direction::Inbound,
|
||||
interface_index,
|
||||
sub_interface_index,
|
||||
);
|
||||
}
|
||||
|
||||
struct ConnectionInfo {
|
||||
verdict: Verdict,
|
||||
process_id: u64,
|
||||
redirect_info: Option<RedirectInfo>,
|
||||
}
|
||||
|
||||
impl ConnectionInfo {
|
||||
fn from_connection<T: Connection>(conn: &T) -> Self {
|
||||
ConnectionInfo {
|
||||
verdict: conn.get_verdict(),
|
||||
process_id: conn.get_process_id(),
|
||||
redirect_info: conn.redirect_info(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fast_track_pm_packets(key: &Key, direction: Direction) -> bool {
|
||||
match direction {
|
||||
Direction::Outbound => {
|
||||
if key.local_port == PM_DNS_PORT || key.local_port == PM_SPN_PORT {
|
||||
return key.local_address == key.remote_address;
|
||||
}
|
||||
}
|
||||
Direction::Inbound => {
|
||||
if key.local_port == PM_DNS_PORT || key.local_port == PM_SPN_PORT {
|
||||
return key.local_address == key.remote_address;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
fn ip_packet_layer(
|
||||
mut data: CalloutData,
|
||||
ipv6: bool,
|
||||
direction: Direction,
|
||||
interface_index: u32,
|
||||
sub_interface_index: u32,
|
||||
) {
|
||||
let Some(device) = crate::entry::get_device() else {
|
||||
return;
|
||||
};
|
||||
if device
|
||||
.injector
|
||||
.was_network_packet_injected_by_self(data.get_layer_data() as _, ipv6)
|
||||
{
|
||||
data.action_permit();
|
||||
return;
|
||||
}
|
||||
|
||||
for mut nbl in NetBufferListIter::new(data.get_layer_data() as _) {
|
||||
if let Direction::Inbound = direction {
|
||||
// The header is not part of the NBL for incoming packets. Move the beginning of the buffer back so we get access to it.
|
||||
// The NBL will auto advance after it loses scope.
|
||||
if ipv6 {
|
||||
nbl.retreat(IPV6_HEADER_LEN as u32, true);
|
||||
} else {
|
||||
nbl.retreat(IPV4_HEADER_LEN as u32, true);
|
||||
}
|
||||
}
|
||||
|
||||
// Get key from packet.
|
||||
let key = match if ipv6 {
|
||||
get_key_from_nbl_v6(&nbl, direction)
|
||||
} else {
|
||||
get_key_from_nbl_v4(&nbl, direction)
|
||||
} {
|
||||
Ok(key) => key,
|
||||
Err(err) => {
|
||||
warn!("failed to get key from nbl: {}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if fast_track_pm_packets(&key, direction) {
|
||||
data.action_permit();
|
||||
return;
|
||||
}
|
||||
|
||||
let mut is_tmp_verdict = false;
|
||||
let mut process_id = 0;
|
||||
|
||||
if matches!(
|
||||
key.protocol,
|
||||
smoltcp::wire::IpProtocol::Tcp | smoltcp::wire::IpProtocol::Udp
|
||||
) {
|
||||
if let Some(mut conn_info) =
|
||||
get_connection_info(&mut device.connection_cache, &key, ipv6)
|
||||
{
|
||||
process_id = conn_info.process_id;
|
||||
// Check if there is action for this connection.
|
||||
match conn_info.verdict {
|
||||
Verdict::Undecided | Verdict::Accept | Verdict::Block | Verdict::Drop => {
|
||||
is_tmp_verdict = true
|
||||
}
|
||||
Verdict::PermanentAccept => data.action_permit(),
|
||||
Verdict::PermanentBlock => data.action_block(),
|
||||
Verdict::Undeterminable | Verdict::PermanentDrop | Verdict::Failed => {
|
||||
data.block_and_absorb()
|
||||
}
|
||||
Verdict::RedirectNameServer | Verdict::RedirectTunnel => {
|
||||
if let Some(redirect_info) = conn_info.redirect_info.take() {
|
||||
match clone_packet(
|
||||
device,
|
||||
nbl,
|
||||
direction,
|
||||
ipv6,
|
||||
key.is_loopback(),
|
||||
interface_index,
|
||||
sub_interface_index,
|
||||
) {
|
||||
Ok(mut packet) => {
|
||||
let _ = packet.redirect(redirect_info);
|
||||
if let Err(err) = device.inject_packet(packet, false) {
|
||||
err!("failed to inject packet: {}", err);
|
||||
}
|
||||
}
|
||||
Err(err) => err!("failed to clone packet: {}", err),
|
||||
}
|
||||
}
|
||||
|
||||
// This will block the original packet. Even if injection failed.
|
||||
data.block_and_absorb();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TCP and UDP always need to go through ALE layer first.
|
||||
if matches!(direction, Direction::Inbound) {
|
||||
// If it's an inbound packet and the connection is not found, we need to continue to ALE layer
|
||||
data.action_permit();
|
||||
return;
|
||||
} else {
|
||||
// This happens sometimes. Leave the decision for portmaster. TODO(vladimir): Find out why.
|
||||
err!("Invalid state for: {}", key);
|
||||
is_tmp_verdict = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Every other protocol treat as a tmp verdict.
|
||||
is_tmp_verdict = true;
|
||||
}
|
||||
|
||||
// Clone packet and send to Portmaster if it's a temporary verdict.
|
||||
if is_tmp_verdict {
|
||||
let packet = match clone_packet(
|
||||
device,
|
||||
nbl,
|
||||
direction,
|
||||
ipv6,
|
||||
key.is_loopback(),
|
||||
interface_index,
|
||||
sub_interface_index,
|
||||
) {
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
err!("failed to clone packet: {}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let info = device
|
||||
.packet_cache
|
||||
.push((key, packet), process_id, direction, false);
|
||||
// Send to Portmaster
|
||||
if let Some(info) = info {
|
||||
let _ = device.event_queue.push(info);
|
||||
}
|
||||
data.block_and_absorb();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn clone_packet(
|
||||
device: &mut Device,
|
||||
nbl: NetBufferList,
|
||||
direction: Direction,
|
||||
ipv6: bool,
|
||||
loopback: bool,
|
||||
interface_index: u32,
|
||||
sub_interface_index: u32,
|
||||
) -> Result<Packet, String> {
|
||||
let clone = nbl.clone(&device.network_allocator)?;
|
||||
let inbound = match direction {
|
||||
Direction::Outbound => false,
|
||||
Direction::Inbound => true,
|
||||
};
|
||||
Ok(Packet::PacketLayer(
|
||||
clone,
|
||||
InjectInfo {
|
||||
ipv6,
|
||||
inbound,
|
||||
loopback,
|
||||
interface_index,
|
||||
sub_interface_index,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
fn get_connection_info(
|
||||
connection_cache: &mut ConnectionCache,
|
||||
key: &Key,
|
||||
ipv6: bool,
|
||||
) -> Option<ConnectionInfo> {
|
||||
if ipv6 {
|
||||
let conn_info = connection_cache.read_connection_v6(
|
||||
&key,
|
||||
|conn: &ConnectionV6| -> Option<ConnectionInfo> {
|
||||
// Function is is behind spin lock. Just copy and return.
|
||||
Some(ConnectionInfo::from_connection(conn))
|
||||
},
|
||||
);
|
||||
return conn_info;
|
||||
} else {
|
||||
let conn_info = connection_cache.read_connection_v4(
|
||||
&key,
|
||||
|conn: &ConnectionV4| -> Option<ConnectionInfo> {
|
||||
// Function is is behind spin lock. Just copy and return.
|
||||
Some(ConnectionInfo::from_connection(conn))
|
||||
},
|
||||
);
|
||||
return conn_info;
|
||||
}
|
||||
}
|
||||
399
windows_kext/driver/src/packet_util.rs
Normal file
399
windows_kext/driver/src/packet_util.rs
Normal file
@@ -0,0 +1,399 @@
|
||||
use alloc::string::{String, ToString};
|
||||
use smoltcp::wire::{
|
||||
IpAddress, IpProtocol, Ipv4Address, Ipv4Packet, Ipv6Address, Ipv6Packet, TcpPacket, UdpPacket,
|
||||
};
|
||||
use wdk::filter_engine::net_buffer::NetBufferList;
|
||||
|
||||
use crate::connection_map::Key;
|
||||
use crate::device::Packet;
|
||||
use crate::{
|
||||
connection::{Direction, RedirectInfo},
|
||||
dbg, err,
|
||||
};
|
||||
|
||||
/// `Redirect` is a trait that defines a method for redirecting network packets.
|
||||
///
|
||||
/// This trait is used to implement different strategies for redirecting packets,
|
||||
/// depending on the specific requirements of the application.
|
||||
pub trait Redirect {
|
||||
/// Redirects a network packet based on the provided `RedirectInfo`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `redirect_info` - A struct containing information about how to redirect the packet.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(())` if the packet was successfully redirected.
|
||||
/// * `Err(String)` if there was an error redirecting the packet.
|
||||
fn redirect(&mut self, redirect_info: RedirectInfo) -> Result<(), String>;
|
||||
}
|
||||
|
||||
impl Redirect for Packet {
|
||||
fn redirect(&mut self, redirect_info: RedirectInfo) -> Result<(), String> {
|
||||
if let Packet::PacketLayer(nbl, inject_info) = self {
|
||||
let Some(data) = nbl.get_data_mut() else {
|
||||
return Err("trying to redirect immutable NBL".to_string());
|
||||
};
|
||||
|
||||
if inject_info.inbound {
|
||||
redirect_inbound_packet(
|
||||
data,
|
||||
redirect_info.local_address,
|
||||
redirect_info.remote_address,
|
||||
redirect_info.remote_port,
|
||||
)
|
||||
} else {
|
||||
redirect_outbound_packet(
|
||||
data,
|
||||
redirect_info.redirect_address,
|
||||
redirect_info.redirect_port,
|
||||
redirect_info.unify,
|
||||
)
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
// return Err("can't redirect from non packet layer".to_string());
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
/// Redirects an outbound packet to a specified remote address and port.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `packet` - A mutable reference to the packet data.
|
||||
/// * `remote_address` - The IP address to redirect the packet to.
|
||||
/// * `remote_port` - The port to redirect the packet to.
|
||||
/// * `unify` - If true, the source and destination addresses of the packet will be set to the same value.
|
||||
///
|
||||
/// This function modifies the packet in-place to change its destination address and port.
|
||||
/// It also updates the checksums for the IP and transport layer headers.
|
||||
/// If the `unify` parameter is true, it sets the source and destination addresses to be the same.
|
||||
/// If the remote address is a loopback address, it sets the source address to the loopback address.
|
||||
fn redirect_outbound_packet(
|
||||
packet: &mut [u8],
|
||||
remote_address: IpAddress,
|
||||
remote_port: u16,
|
||||
unify: bool,
|
||||
) {
|
||||
match remote_address {
|
||||
IpAddress::Ipv4(remote_address) => {
|
||||
if let Ok(mut ip_packet) = Ipv4Packet::new_checked(packet) {
|
||||
if unify {
|
||||
ip_packet.set_dst_addr(ip_packet.src_addr());
|
||||
} else {
|
||||
ip_packet.set_dst_addr(remote_address);
|
||||
if remote_address.is_loopback() {
|
||||
ip_packet.set_src_addr(Ipv4Address::new(127, 0, 0, 1));
|
||||
}
|
||||
}
|
||||
ip_packet.fill_checksum();
|
||||
let src_addr = ip_packet.src_addr();
|
||||
let dst_addr = ip_packet.dst_addr();
|
||||
if ip_packet.next_header() == IpProtocol::Udp {
|
||||
if let Ok(mut udp_packet) = UdpPacket::new_checked(ip_packet.payload_mut()) {
|
||||
udp_packet.set_dst_port(remote_port);
|
||||
udp_packet
|
||||
.fill_checksum(&IpAddress::Ipv4(src_addr), &IpAddress::Ipv4(dst_addr));
|
||||
}
|
||||
}
|
||||
if ip_packet.next_header() == IpProtocol::Tcp {
|
||||
if let Ok(mut tcp_packet) = TcpPacket::new_checked(ip_packet.payload_mut()) {
|
||||
tcp_packet.set_dst_port(remote_port);
|
||||
tcp_packet
|
||||
.fill_checksum(&IpAddress::Ipv4(src_addr), &IpAddress::Ipv4(dst_addr));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
IpAddress::Ipv6(remote_address) => {
|
||||
if let Ok(mut ip_packet) = Ipv6Packet::new_checked(packet) {
|
||||
ip_packet.set_dst_addr(remote_address);
|
||||
if unify {
|
||||
ip_packet.set_dst_addr(ip_packet.src_addr());
|
||||
} else {
|
||||
ip_packet.set_dst_addr(remote_address);
|
||||
if remote_address.is_loopback() {
|
||||
ip_packet.set_src_addr(Ipv6Address::LOOPBACK);
|
||||
}
|
||||
}
|
||||
let src_addr = ip_packet.src_addr();
|
||||
let dst_addr = ip_packet.dst_addr();
|
||||
if ip_packet.next_header() == IpProtocol::Udp {
|
||||
if let Ok(mut udp_packet) = UdpPacket::new_checked(ip_packet.payload_mut()) {
|
||||
udp_packet.set_dst_port(remote_port);
|
||||
udp_packet
|
||||
.fill_checksum(&IpAddress::Ipv6(src_addr), &IpAddress::Ipv6(dst_addr));
|
||||
}
|
||||
}
|
||||
if ip_packet.next_header() == IpProtocol::Tcp {
|
||||
if let Ok(mut tcp_packet) = TcpPacket::new_checked(ip_packet.payload_mut()) {
|
||||
tcp_packet.set_dst_port(remote_port);
|
||||
tcp_packet
|
||||
.fill_checksum(&IpAddress::Ipv6(src_addr), &IpAddress::Ipv6(dst_addr));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Redirects an inbound packet to a local address.
|
||||
///
|
||||
/// This function takes a mutable reference to a packet and modifies it in place.
|
||||
/// It changes the destination address to the provided local address and the source address
|
||||
/// to the original remote address. It also sets the source port to the original remote port.
|
||||
/// The function handles both IPv4 and IPv6 addresses.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `packet` - A mutable reference to the packet data.
|
||||
/// * `local_address` - The local IP address to redirect the packet to.
|
||||
/// * `original_remote_address` - The original remote IP address of the packet.
|
||||
/// * `original_remote_port` - The original remote port of the packet.
|
||||
///
|
||||
fn redirect_inbound_packet(
|
||||
packet: &mut [u8],
|
||||
local_address: IpAddress,
|
||||
original_remote_address: IpAddress,
|
||||
original_remote_port: u16,
|
||||
) {
|
||||
match local_address {
|
||||
IpAddress::Ipv4(local_address) => {
|
||||
let IpAddress::Ipv4(original_remote_address) = original_remote_address else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Ok(mut ip_packet) = Ipv4Packet::new_checked(packet) {
|
||||
ip_packet.set_dst_addr(local_address);
|
||||
ip_packet.set_src_addr(original_remote_address);
|
||||
ip_packet.fill_checksum();
|
||||
let src_addr = ip_packet.src_addr();
|
||||
let dst_addr = ip_packet.dst_addr();
|
||||
if ip_packet.next_header() == IpProtocol::Udp {
|
||||
if let Ok(mut udp_packet) = UdpPacket::new_checked(ip_packet.payload_mut()) {
|
||||
udp_packet.set_src_port(original_remote_port);
|
||||
udp_packet
|
||||
.fill_checksum(&IpAddress::Ipv4(src_addr), &IpAddress::Ipv4(dst_addr));
|
||||
}
|
||||
}
|
||||
if ip_packet.next_header() == IpProtocol::Tcp {
|
||||
if let Ok(mut tcp_packet) = TcpPacket::new_checked(ip_packet.payload_mut()) {
|
||||
tcp_packet.set_src_port(original_remote_port);
|
||||
tcp_packet
|
||||
.fill_checksum(&IpAddress::Ipv4(src_addr), &IpAddress::Ipv4(dst_addr));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
IpAddress::Ipv6(local_address) => {
|
||||
if let Ok(mut ip_packet) = Ipv6Packet::new_checked(packet) {
|
||||
let IpAddress::Ipv6(original_remote_address) = original_remote_address else {
|
||||
return;
|
||||
};
|
||||
ip_packet.set_dst_addr(local_address);
|
||||
ip_packet.set_src_addr(original_remote_address);
|
||||
let src_addr = ip_packet.src_addr();
|
||||
let dst_addr = ip_packet.dst_addr();
|
||||
if ip_packet.next_header() == IpProtocol::Udp {
|
||||
if let Ok(mut udp_packet) = UdpPacket::new_checked(ip_packet.payload_mut()) {
|
||||
udp_packet.set_src_port(original_remote_port);
|
||||
udp_packet
|
||||
.fill_checksum(&IpAddress::Ipv6(src_addr), &IpAddress::Ipv6(dst_addr));
|
||||
}
|
||||
}
|
||||
if ip_packet.next_header() == IpProtocol::Tcp {
|
||||
if let Ok(mut tcp_packet) = TcpPacket::new_checked(ip_packet.payload_mut()) {
|
||||
tcp_packet.set_src_port(original_remote_port);
|
||||
tcp_packet
|
||||
.fill_checksum(&IpAddress::Ipv6(src_addr), &IpAddress::Ipv6(dst_addr));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn print_packet(packet: &[u8]) {
|
||||
if let Ok(ip_packet) = Ipv4Packet::new_checked(packet) {
|
||||
if ip_packet.next_header() == IpProtocol::Udp {
|
||||
if let Ok(udp_packet) = UdpPacket::new_checked(ip_packet.payload()) {
|
||||
dbg!("packet {} {}", ip_packet, udp_packet);
|
||||
}
|
||||
}
|
||||
if ip_packet.next_header() == IpProtocol::Tcp {
|
||||
if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
|
||||
dbg!("packet {} {}", ip_packet, tcp_packet);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err!("failed to print packet: invalid ip header: {:?}", packet);
|
||||
}
|
||||
}
|
||||
|
||||
/// This function extracts a key from a given IPv4 network buffer list (NBL).
|
||||
/// The key contains the protocol, local and remote addresses and ports.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `nbl` - A reference to the network buffer list from which the key will be extracted.
|
||||
/// * `direction` - The direction of the packet (Inbound or Outbound).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(Key)` - A key containing the protocol, local and remote addresses and ports.
|
||||
/// * `Err(String)` - An error message if the function fails to get net_buffer data.
|
||||
const HEADERS_LEN: usize = smoltcp::wire::IPV4_HEADER_LEN + smoltcp::wire::TCP_HEADER_LEN;
|
||||
|
||||
fn get_ports(packet: &[u8], protocol: smoltcp::wire::IpProtocol) -> (u16, u16) {
|
||||
match protocol {
|
||||
smoltcp::wire::IpProtocol::Tcp => {
|
||||
let tcp_packet = TcpPacket::new_unchecked(packet);
|
||||
(tcp_packet.src_port(), tcp_packet.dst_port())
|
||||
}
|
||||
smoltcp::wire::IpProtocol::Udp => {
|
||||
let udp_packet = UdpPacket::new_unchecked(packet);
|
||||
(udp_packet.src_port(), udp_packet.dst_port())
|
||||
}
|
||||
_ => (0, 0), // No ports for other protocols
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_key_from_nbl_v4(nbl: &NetBufferList, direction: Direction) -> Result<Key, String> {
|
||||
// Get bytes
|
||||
let mut headers = [0; HEADERS_LEN];
|
||||
if nbl.read_bytes(&mut headers).is_err() {
|
||||
return Err("failed to get net_buffer data".to_string());
|
||||
}
|
||||
|
||||
// Parse packet
|
||||
let ip_packet = Ipv4Packet::new_unchecked(&headers);
|
||||
let (src_port, dst_port) = get_ports(
|
||||
&headers[smoltcp::wire::IPV4_HEADER_LEN..],
|
||||
ip_packet.next_header(),
|
||||
);
|
||||
|
||||
// Build key
|
||||
match direction {
|
||||
Direction::Outbound => Ok(Key {
|
||||
protocol: ip_packet.next_header(),
|
||||
local_address: IpAddress::Ipv4(ip_packet.src_addr()),
|
||||
local_port: src_port,
|
||||
remote_address: IpAddress::Ipv4(ip_packet.dst_addr()),
|
||||
remote_port: dst_port,
|
||||
}),
|
||||
Direction::Inbound => Ok(Key {
|
||||
protocol: ip_packet.next_header(),
|
||||
local_address: IpAddress::Ipv4(ip_packet.dst_addr()),
|
||||
local_port: dst_port,
|
||||
remote_address: IpAddress::Ipv4(ip_packet.src_addr()),
|
||||
remote_port: src_port,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// This function extracts a key from a given IPv6 network buffer list (NBL).
|
||||
/// The key contains the protocol, local and remote addresses and ports.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `nbl` - A reference to the network buffer list from which the key will be extracted.
|
||||
/// * `direction` - The direction of the packet (Inbound or Outbound).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(Key)` - A key containing the protocol, local and remote addresses and ports.
|
||||
/// * `Err(String)` - An error message if the function fails to get net_buffer data.
|
||||
pub fn get_key_from_nbl_v6(nbl: &NetBufferList, direction: Direction) -> Result<Key, String> {
|
||||
// Get bytes
|
||||
let mut headers = [0; smoltcp::wire::IPV6_HEADER_LEN + smoltcp::wire::TCP_HEADER_LEN];
|
||||
let Ok(()) = nbl.read_bytes(&mut headers) else {
|
||||
return Err("failed to get net_buffer data".to_string());
|
||||
};
|
||||
// Parse packet
|
||||
let ip_packet = Ipv6Packet::new_unchecked(&headers);
|
||||
let (src_port, dst_port) = get_ports(
|
||||
&headers[smoltcp::wire::IPV6_HEADER_LEN..],
|
||||
ip_packet.next_header(),
|
||||
);
|
||||
|
||||
// Build key
|
||||
match direction {
|
||||
Direction::Outbound => Ok(Key {
|
||||
protocol: ip_packet.next_header(),
|
||||
local_address: IpAddress::Ipv6(ip_packet.src_addr()),
|
||||
local_port: src_port,
|
||||
remote_address: IpAddress::Ipv6(ip_packet.dst_addr()),
|
||||
remote_port: dst_port,
|
||||
}),
|
||||
Direction::Inbound => Ok(Key {
|
||||
protocol: ip_packet.next_header(),
|
||||
local_address: IpAddress::Ipv6(ip_packet.dst_addr()),
|
||||
local_port: dst_port,
|
||||
remote_address: IpAddress::Ipv6(ip_packet.src_addr()),
|
||||
remote_port: src_port,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// Converts a given key into connection information.
|
||||
//
|
||||
// This function takes a key, packet id, process id, and direction as input.
|
||||
// It then uses these to create a new `ConnectionInfoV6` or `ConnectionInfoV4` object,
|
||||
// depending on whether the IP addresses in the key are IPv6 or IPv4 respectively.
|
||||
//
|
||||
// # Arguments
|
||||
//
|
||||
// * `key` - A reference to the key object containing the connection details.
|
||||
// * `packet_id` - The id of the packet.
|
||||
// * `process_id` - The id of the process.
|
||||
// * `direction` - The direction of the connection.
|
||||
//
|
||||
// # Returns
|
||||
//
|
||||
// * `Some(Box<dyn Info>)` - A boxed `Info` trait object if the key contains valid IPv4 or IPv6 addresses.
|
||||
// * `None` - If the key does not contain valid IPv4 or IPv6 addresses.
|
||||
// pub fn key_to_connection_info(
|
||||
// key: &Key,
|
||||
// packet_id: u64,
|
||||
// process_id: u64,
|
||||
// direction: Direction,
|
||||
// payload: &[u8],
|
||||
// ) -> Option<Info> {
|
||||
// let (local_port, remote_port) = match key.protocol {
|
||||
// IpProtocol::Tcp | IpProtocol::Udp => (key.local_port, key.remote_port),
|
||||
// _ => (0, 0),
|
||||
// };
|
||||
|
||||
// match (key.local_address, key.remote_address) {
|
||||
// (IpAddress::Ipv6(local_ip), IpAddress::Ipv6(remote_ip)) if key.is_ipv6() => {
|
||||
// Some(protocol::info::connection_info_v6(
|
||||
// packet_id,
|
||||
// process_id,
|
||||
// direction as u8,
|
||||
// u8::from(key.protocol),
|
||||
// local_ip.0,
|
||||
// remote_ip.0,
|
||||
// local_port,
|
||||
// remote_port,
|
||||
// payload,
|
||||
// ))
|
||||
// }
|
||||
// (IpAddress::Ipv4(local_ip), IpAddress::Ipv4(remote_ip)) => {
|
||||
// Some(protocol::info::connection_info_v4(
|
||||
// packet_id,
|
||||
// process_id,
|
||||
// direction as u8,
|
||||
// u8::from(key.protocol),
|
||||
// local_ip.0,
|
||||
// remote_ip.0,
|
||||
// local_port,
|
||||
// remote_port,
|
||||
// payload,
|
||||
// ))
|
||||
// }
|
||||
// _ => None,
|
||||
// }
|
||||
// }
|
||||
203
windows_kext/driver/src/stream_callouts.rs
Normal file
203
windows_kext/driver/src/stream_callouts.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
use smoltcp::wire::{Ipv4Address, Ipv6Address};
|
||||
use wdk::filter_engine::{callout_data::CalloutData, layer, net_buffer::NetBufferListIter};
|
||||
|
||||
use crate::{bandwidth, connection::Direction};
|
||||
|
||||
pub fn stream_layer_tcp_v4(data: CalloutData) {
|
||||
let Some(device) = crate::entry::get_device() else {
|
||||
return;
|
||||
};
|
||||
let mut direction = Direction::Outbound;
|
||||
let data_length = if let Some(packet) = data.get_stream_callout_packet() {
|
||||
if packet.is_receive() {
|
||||
direction = Direction::Inbound;
|
||||
}
|
||||
packet.get_data_len()
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
type Fields = layer::FieldsStreamV4;
|
||||
let local_ip = Ipv4Address::from_bytes(
|
||||
&data
|
||||
.get_value_u32(Fields::IpLocalAddress as usize)
|
||||
.to_be_bytes(),
|
||||
);
|
||||
let local_port = data.get_value_u16(Fields::IpLocalPort as usize);
|
||||
let remote_ip = Ipv4Address::from_bytes(
|
||||
&data
|
||||
.get_value_u32(Fields::IpRemoteAddress as usize)
|
||||
.to_be_bytes(),
|
||||
);
|
||||
let remote_port = data.get_value_u16(Fields::IpRemotePort as usize);
|
||||
match direction {
|
||||
Direction::Outbound => {
|
||||
device.bandwidth_stats.update_tcp_v4_tx(
|
||||
bandwidth::Key {
|
||||
local_ip,
|
||||
local_port,
|
||||
remote_ip,
|
||||
remote_port,
|
||||
},
|
||||
data_length,
|
||||
);
|
||||
}
|
||||
Direction::Inbound => {
|
||||
device.bandwidth_stats.update_tcp_v4_rx(
|
||||
bandwidth::Key {
|
||||
local_ip,
|
||||
local_port,
|
||||
remote_ip,
|
||||
remote_port,
|
||||
},
|
||||
data_length,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stream_layer_tcp_v6(data: CalloutData) {
|
||||
let Some(device) = crate::entry::get_device() else {
|
||||
return;
|
||||
};
|
||||
let mut direction = Direction::Outbound;
|
||||
let data_length = if let Some(packet) = data.get_stream_callout_packet() {
|
||||
if packet.is_receive() {
|
||||
direction = Direction::Inbound;
|
||||
}
|
||||
packet.get_data_len()
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
type Fields = layer::FieldsStreamV6;
|
||||
if data_length == 0 {
|
||||
return;
|
||||
}
|
||||
let local_ip =
|
||||
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpLocalAddress as usize));
|
||||
let local_port = data.get_value_u16(Fields::IpLocalPort as usize);
|
||||
let remote_ip =
|
||||
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpRemoteAddress as usize));
|
||||
let remote_port = data.get_value_u16(Fields::IpRemotePort as usize);
|
||||
match direction {
|
||||
Direction::Outbound => {
|
||||
device.bandwidth_stats.update_tcp_v6_tx(
|
||||
bandwidth::Key {
|
||||
local_ip,
|
||||
local_port,
|
||||
remote_ip,
|
||||
remote_port,
|
||||
},
|
||||
data_length,
|
||||
);
|
||||
}
|
||||
Direction::Inbound => {
|
||||
device.bandwidth_stats.update_tcp_v6_rx(
|
||||
bandwidth::Key {
|
||||
local_ip,
|
||||
local_port,
|
||||
remote_ip,
|
||||
remote_port,
|
||||
},
|
||||
data_length,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stream_layer_udp_v4(data: CalloutData) {
|
||||
let Some(device) = crate::entry::get_device() else {
|
||||
return;
|
||||
};
|
||||
let mut data_length: usize = 0;
|
||||
for nbl in NetBufferListIter::new(data.get_layer_data() as _) {
|
||||
data_length += nbl.get_data_length() as usize;
|
||||
}
|
||||
type Fields = layer::FieldsDatagramDataV4;
|
||||
let mut direction = Direction::Inbound;
|
||||
if data.get_value_u8(Fields::Direction as usize) == 0 {
|
||||
direction = Direction::Outbound;
|
||||
}
|
||||
|
||||
let local_ip = Ipv4Address::from_bytes(
|
||||
&data
|
||||
.get_value_u32(Fields::IpLocalAddress as usize)
|
||||
.to_be_bytes(),
|
||||
);
|
||||
let local_port = data.get_value_u16(Fields::IpLocalPort as usize);
|
||||
let remote_ip = Ipv4Address::from_bytes(
|
||||
&data
|
||||
.get_value_u32(Fields::IpRemoteAddress as usize)
|
||||
.to_be_bytes(),
|
||||
);
|
||||
let remote_port = data.get_value_u16(Fields::IpRemotePort as usize);
|
||||
match direction {
|
||||
Direction::Outbound => {
|
||||
device.bandwidth_stats.update_udp_v4_tx(
|
||||
bandwidth::Key {
|
||||
local_ip,
|
||||
local_port,
|
||||
remote_ip,
|
||||
remote_port,
|
||||
},
|
||||
data_length,
|
||||
);
|
||||
}
|
||||
Direction::Inbound => {
|
||||
device.bandwidth_stats.update_udp_v4_rx(
|
||||
bandwidth::Key {
|
||||
local_ip,
|
||||
local_port,
|
||||
remote_ip,
|
||||
remote_port,
|
||||
},
|
||||
data_length,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stream_layer_udp_v6(data: CalloutData) {
|
||||
let Some(device) = crate::entry::get_device() else {
|
||||
return;
|
||||
};
|
||||
let mut data_length: usize = 0;
|
||||
for nbl in NetBufferListIter::new(data.get_layer_data() as _) {
|
||||
data_length += nbl.get_data_length() as usize;
|
||||
}
|
||||
type Fields = layer::FieldsDatagramDataV6;
|
||||
let mut direction = Direction::Inbound;
|
||||
if data.get_value_u8(Fields::Direction as usize) == 0 {
|
||||
direction = Direction::Outbound;
|
||||
}
|
||||
|
||||
let local_ip =
|
||||
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpLocalAddress as usize));
|
||||
let local_port = data.get_value_u16(Fields::IpLocalPort as usize);
|
||||
let remote_ip =
|
||||
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpRemoteAddress as usize));
|
||||
let remote_port = data.get_value_u16(Fields::IpRemotePort as usize);
|
||||
match direction {
|
||||
Direction::Outbound => {
|
||||
device.bandwidth_stats.update_udp_v6_tx(
|
||||
bandwidth::Key {
|
||||
local_ip,
|
||||
local_port,
|
||||
remote_ip,
|
||||
remote_port,
|
||||
},
|
||||
data_length,
|
||||
);
|
||||
}
|
||||
Direction::Inbound => {
|
||||
device.bandwidth_stats.update_udp_v6_rx(
|
||||
bandwidth::Key {
|
||||
local_ip,
|
||||
local_port,
|
||||
remote_ip,
|
||||
remote_port,
|
||||
},
|
||||
data_length,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user