syndicate_utils/src/net_mapper.nim

168 lines
5.1 KiB
Nim

# SPDX-FileCopyrightText: ☭ Emery Hemingway
# SPDX-License-Identifier: Unlicense
## A ping utility for Syndicate.
import std/[asyncdispatch, asyncnet, monotimes, nativesockets, net, os, strutils, tables, times]
import preserves
import syndicate, syndicate/relays
import ./schema/net_mapper
#[
var
SOL_IP {.importc, nodecl, header: "<sys/socket.h>".}: int
IP_TTL {.importc, nodecl, header: "<netinet/in.h>".}: int
]#
proc toPreservesHook(address: IpAddress): Value =
  ## Encode an `IpAddress` as a Preserves value holding its textual form
  ## (the output of `$address`).
  let text = $address
  result = toPreserves(text)
proc fromPreservesHook(address: var IpAddress; pr: Value): bool =
  ## Decode an `IpAddress` from the Preserves value `pr`.
  ##
  ## Returns `true` and stores the parsed address in `address` when `pr` is
  ## a string that `parseIpAddress` accepts. Returns `false` when `pr` is
  ## not a string or the text does not parse (the `ValueError` is swallowed
  ## on purpose).
  result = false
  if pr.isString:
    try:
      address = parseIpAddress(pr.string)
      result = true
    except ValueError:
      discard
when isMainModule:
  # Self-check of the two hooks above: a zero-initialised `IpAddress` is
  # encoded and then decoded again, and the decode must report success.
  var probe: IpAddress
  let encoded = toPreservesHook(probe)
  assert fromPreservesHook(probe, encoded)
type
  IcmpTypes = enum
    icmpEchoReply = 0 ## value stored in the header for an echo reply
    icmpEcho = 8      ## value stored in the header for an echo request

  IcmpHeader {.packed.} = object
    ## First four bytes of the message: type, code and checksum.
    `type`: byte
    code: byte
    checksum: uint16

  IcmpEchoFields {.packed.} = object
    ## Structured view of an eight byte echo message.
    header: IcmpHeader
    identifier: array[2, byte]
    sequenceNumber: uint16

  IcmpEcho {.union.} = object
    ## An echo message that can be accessed either field by field or as
    ## the raw bytes handed to the socket; both views share storage.
    fields: IcmpEchoFields
    buffer: array[8, byte]
proc initIcmpEcho(): IcmpEcho =
  ## Build an echo request: the header type is set to `icmpEcho` and every
  ## other byte keeps its zero default.
  result.fields.header.`type` = uint8(ord(icmpEcho))
  # The identifier is left at zero. Open question kept from the original:
  # doAssert urandom(result.fields.identifier) # Linux does this?
proc updateChecksum(msg: var IcmpEcho) =
  ## Recompute the header checksum of `msg` in place.
  ##
  ## The checksum field is zeroed, the eight message bytes are summed as
  ## four 16-bit words, carries above bit 15 are folded back into the low
  ## half, and the bitwise complement of the result is stored.
  msg.fields.header.checksum = 0
  let words = cast[array[4, uint16]](msg.buffer)
  var acc = 0'u32
  for word in words:
    acc += uint32(word)
  # Fold until nothing is left above the low 16 bits.
  while (acc shr 16) != 0:
    acc = (acc and 0xffff) + (acc shr 16)
  msg.fields.header.checksum = not uint16(acc)
proc match(a, b: IcmpEchoFields): bool =
  ## Report whether `a` and `b` form a request/reply pair: one has type
  ## `icmpEcho` and the other `icmpEchoReply`, and both carry the same code
  ## and sequence number. The identifier is not compared.
  let
    seenTypes = {a.header.`type`, b.header.`type`}
    echoPair = {uint8(icmpEcho), uint8(icmpEchoReply)}
  result = seenTypes == echoPair and
    a.header.code == b.header.code and
    a.sequenceNumber == b.sequenceNumber
type
  Pinger = ref object
    ## Everything needed to ping one address repeatedly and publish the
    ## measured round trip times.
    # Syndicate side
    facet: Facet          ## facet whose `run` is used to publish or retract
    ds: Cap               ## capability the `rtt` record is published to
    rtt: RoundTripTime    ## record holding address, minimum, maximum, average
    rttHandle: Handle     ## handle of the published record, `Handle(0)` if none
    # Network side
    socket: AsyncSocket   ## socket used to send requests and receive replies
    sad: Sockaddr_storage ## destination address filled in by `toSockAddr`
    sadLen: SockLen       ## length of the address stored in `sad`
    msg: IcmpEcho         ## request message, reused for every exchange
    # Statistics and pacing
    sum: Duration         ## total of all measured round trip times
    count: int64          ## number of measurements added to `sum`
    interval: Duration    ## delay before the next exchange is started
proc newPinger(address: IpAddress; facet: Facet; ds: Cap): Pinger =
  ## Create a `Pinger` for `address` that publishes to `ds` through `facet`.
  ##
  ## A datagram ICMP socket (unbuffered, inheritable) is opened, the
  ## destination socket address is derived from `address` with port 0, and
  ## the initial interval is 500 milliseconds. Opening the socket may raise.
  new result
  result.facet = facet
  result.ds = ds
  result.rtt = RoundTripTime(address: $address)
  result.msg = initIcmpEcho()
  result.socket = newAsyncSocket(
    AF_INET, SOCK_DGRAM, IPPROTO_ICMP, buffered = false, inheritable = true)
  result.interval = initDuration(milliseconds = 500)
  toSockAddr(address, Port(0), result.sad, result.sadLen)
  # Not done yet (kept from the original):
  # setSockOptInt(getFd socket, SOL_IP, IP_TTL, _)
proc close(ping: Pinger) =
  ## Close the socket owned by `ping`; `kick` stops rescheduling once the
  ## socket reports itself closed.
  ping.socket.close()
proc sqr(dur: Duration): Duration =
  ## Return a `Duration` whose microsecond count is the square of the whole
  ## microsecond count of `dur`.
  let micros = inMicroseconds(dur)
  result = initDuration(microseconds = micros * micros)
proc update(ping: Pinger; dur: Duration) {.inline.} =
  ## Fold one measured round trip time `dur` into the statistics of `ping`.
  ##
  ## The first sample sets both minimum and maximum; later samples only
  ## widen that range. The average is recomputed from the running total and
  ## sample count. All values stored in `ping.rtt` are in seconds.
  const microsPerSecond = 1_000_000.0
  let seconds = float(inMicroseconds(dur)) / microsPerSecond
  if ping.count == 0:
    ping.rtt.minimum = seconds
    ping.rtt.maximum = seconds
  elif seconds < ping.rtt.minimum:
    ping.rtt.minimum = seconds
  elif seconds > ping.rtt.maximum:
    ping.rtt.maximum = seconds
  ping.count.inc
  ping.sum += dur
  let mean = ping.sum div ping.count
  ping.rtt.average = float(inMicroseconds(mean)) / microsPerSecond
proc exchangeEcho(ping: Pinger) {.async.} =
  ## Send one echo request and wait for the matching reply.
  ##
  ## The sequence number is incremented, the request is sent with `sendto`,
  ## and datagrams are read until one comes from the expected address, is
  ## at least as long as the request, and passes `match`. The elapsed time
  ## is then handed to `update`. Rejected datagrams are reported on stderr
  ## and the wait continues; there is no timeout. A failing `sendto` raises
  ## `OSError`.
  inc ping.msg.fields.sequenceNumber
  # The checksum is not refreshed here; note kept from the original:
  # updateChecksum(ping.msg) # Linux does this?
  let sentAt = getMonoTime()
  let sent = sendto(
    ping.socket.getFd,
    unsafeAddr ping.msg.buffer[0], ping.msg.buffer.len, 0,
    cast[ptr SockAddr](unsafeAddr ping.sad), # API wants the generic pointer
    ping.sadLen)
  if sent == -1'i32:
    raiseOSError(osLastError())
  while true:
    let (data, address, _) = await recvFrom(ping.socket, 128)
    let receivedAt = getMonoTime()
    if address != $ping.rtt.address:
      stderr.writeLine "want ICMP from ", ping.rtt.address, " but received from ", address, " instead"
      continue
    if data.len < ping.msg.buffer.len:
      stderr.writeLine "reply data has a bad length ", data.len
      continue
    # Long enough to be viewed as an echo message.
    let reply = cast[ptr IcmpEcho](unsafeAddr data[0])
    if not match(ping.msg.fields, reply.fields):
      stderr.writeLine "ICMP mismatch"
      continue
    update(ping, receivedAt - sentAt)
    return
proc kick(ping: Pinger) {.gcsafe.} =
  ## Schedule the next echo exchange for `ping` and re-arm afterwards.
  ##
  ## Nothing is scheduled once the socket is closed. Otherwise a one-shot
  ## timer fires after `ping.interval` and starts `exchangeEcho`. When the
  ## exchange completes:
  ## * on failure the published record, if any, is retracted and its handle
  ##   reset; nothing is published;
  ## * on success `ping.rtt` is published to `ping.ds` with `replace`.
  ## The interval is then doubled while it is below 20 seconds and `kick`
  ## is called again.
  if not ping.socket.isClosed:
    addTimer(ping.interval.inMilliseconds.int, oneshot = true) do (fd: AsyncFD) -> bool:
      let fut = exchangeEcho(ping)
      fut.addCallback do ():
        if fut.failed:
          # A failed exchange must never reach the publishing branch. The
          # former combined test `failed and handle != 0` let a failure
          # without a published record fall through and assert statistics
          # that were zeroed or stale.
          if ping.rttHandle != Handle(0):
            ping.facet.run do (turn: var Turn):
              retract(turn, ping.rttHandle)
              reset ping.rttHandle
        else:
          ping.facet.run do (turn: var Turn):
            replace(turn, ping.ds, ping.rttHandle, ping.rtt)
        # Back off: double the delay until it reaches at least 20 seconds.
        if ping.interval < initDuration(seconds = 20):
          ping.interval = ping.interval * 2
        kick(ping)
type
  Args {.preservesDictionary.} = object
    ## Arguments matched at the root capability: a dictionary whose
    ## `dataspace` entry is the capability to observe and publish to.
    dataspace: Cap
# Program entry: boot the "net_mapper" actor. For every `Args` record seen
# at `root`, and for every observer pattern matched in its dataspace, one
# pinger is started per IPv4 address and closed again by the trailing
# `do:` branch.
runActor("net_mapper") do (root: Cap; turn: var Turn):
  connectStdio(turn, root)
  # Pattern matching `Observe` assertions whose pattern is `RoundTripTime`.
  # Presumably `{0: grabLit()}` captures the literal in field 0, i.e. the
  # address being asked about -- confirm against the Syndicate pattern DSL.
  let rttObserver = ?Observe(pattern: !RoundTripTime) ?? {0: grabLit()}
  during(turn, root, ?:Args) do (ds: Cap):
    during(turn, ds, rttObserver) do (address: IpAddress):
      # `ping` stays nil for any address family other than IPv4, so no
      # pinger is created or started for those.
      var ping: Pinger
      if address.family == IpAddressFamily.IPv4:
        ping = newPinger(address, turn.facet, ds)
        kick(ping)
    do:
      # Cleanup branch of the inner `during`: close the socket if a pinger
      # was created. `kick` stops rescheduling once the socket is closed.
      if not ping.isNil: close(ping)