# SPDX-FileCopyrightText: ☭ Emery Hemingway
# SPDX-License-Identifier: Unlicense

## A ping utility for Syndicate.

import std/[asyncdispatch, asyncnet, monotimes, nativesockets, net, os, strutils, tables, times]
import preserves
import syndicate, syndicate/patterns
import ./schema/net_mapper

#[
var
  SOL_IP {.importc, nodecl, header: "".}: int
  IP_TTL {.importc, nodecl, header: "".}: int
]#

proc toPreserveHook(address: IpAddress; E: typedesc): Preserve[E] =
  ## Encode an IP address as the Preserves form of its textual
  ## representation (`$address`).
  result = toPreserve($address, E)

proc fromPreserveHook[E](address: var IpAddress; pr: Preserve[E]): bool =
  ## Decode an IP address from a Preserves string.
  ## Returns `false` (leaving `address` untouched) when `pr` is not a
  ## string or the string does not parse as an IP address.
  try:
    if not pr.isString:
      return false
    address = parseIpAddress(pr.string)
    result = true
  except ValueError:
    result = false

when isMainModule:
  # Round-trip a default-initialized address through both hooks to check
  # that decoding accepts what encoding produces.
  var probe: IpAddress
  assert fromPreserveHook(probe, toPreserveHook(probe, void))

type
  IcmpHeader {.packed.} = object
    ## Common 4-byte ICMP header.
    `type`: uint8
    code: uint8
    checksum: uint16

  IcmpEchoFields {.packed.} = object
    ## Echo request/reply layout: header, identifier, sequence number.
    header: IcmpHeader
    identifier: array[2, byte]
    sequenceNumber: uint16

  IcmpEcho {.union.} = object
    ## The same 8 bytes viewed either as structured fields or as raw bytes.
    fields: IcmpEchoFields
    buffer: array[8, uint8]

  IcmpTypes = enum
    icmpEchoReply = 0
    icmpEcho = 8

proc initIcmpEcho(): IcmpEcho =
  ## Build an echo-request message; every field other than the type is
  ## left zero-initialized.
  result.fields.header.`type` = icmpEcho.uint8
  # doAssert urandom(result.fields.identifier) # Linux does this?
proc updateChecksum(msg: var IcmpEcho) =
  ## Recompute the ICMP checksum over the 8-byte echo message: the
  ## one's-complement of the 16-bit one's-complement sum of the message,
  ## computed with the checksum field zeroed.
  var sum: uint32
  msg.fields.header.checksum = 0
  for n in cast[array[4, uint16]](msg.buffer):
    sum = sum + uint32(n)
  # Fold any carries out of the high 16 bits back into the low 16 bits.
  while (sum and 0xffff0000'u32) != 0:
    sum = (sum and 0xffff) + (sum shr 16)
  msg.fields.header.checksum = not uint16(sum)

proc match(a, b: IcmpEchoFields): bool =
  ## True when one message is an echo request and the other an echo reply
  ## with the same code and sequence number. The identifier is not compared.
  ({a.header.`type`, b.header.`type`} == {uint8 icmpEcho, uint8 icmpEchoReply}) and
    (a.header.code == b.header.code) and
    (a.sequenceNumber == b.sequenceNumber)

type
  Pinger = ref object
    ## State for pinging one IPv4 address and publishing its RTT.
    facet: Facet
    ds: Cap                 # dataspace the RoundTripTime is asserted into
    rtt: RoundTripTime      # running statistics, published on success
    rttHandle: Handle       # handle of the current RTT assertion, or 0
    sum: Duration           # total of all measured round trips
    count: int64            # number of measured round trips
    msg: IcmpEcho           # outgoing echo request (sequence bumped per send)
    socket: AsyncSocket
    sad: Sockaddr_storage   # destination address for sendto
    sadLen: SockLen
    interval: Duration      # delay before the next exchange

proc newPinger(address: IpAddress; facet: Facet; ds: Cap): Pinger =
  ## Create a pinger with an AF_INET datagram ICMP socket aimed at
  ## `address`, starting with a 500 ms interval.
  result = Pinger(
    facet: facet,
    ds: ds,
    rtt: RoundTripTime(address: $address),
    msg: initIcmpEcho(),
    socket: newAsyncSocket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP, false, true),
    interval: initDuration(milliseconds = 500))
  toSockAddr(address, Port 0, result.sad, result.sadLen)
  # setSockOptInt(getFd socket, SOL_IP, IP_TTL, _)

proc close(ping: Pinger) =
  ## Close the pinger's socket; `kick` stops rescheduling once it is closed.
  close(ping.socket)

proc sqr(dur: Duration): Duration =
  ## Square a duration at microsecond resolution.
  let us = dur.inMicroseconds
  initDuration(microseconds = us * us)

proc update(ping: Pinger; dur: Duration) {.inline.} =
  ## Fold one measured round trip into the min/max/average statistics
  ## (all stored in seconds).
  let secs = dur.inMicroseconds.float / 1_000_000.0
  if ping.count == 0:
    (ping.rtt.minimum, ping.rtt.maximum) = (secs, secs)
  elif secs < ping.rtt.minimum:
    ping.rtt.minimum = secs
  elif secs > ping.rtt.maximum:
    ping.rtt.maximum = secs
  ping.sum = ping.sum + dur
  inc ping.count
  ping.rtt.average = inMicroseconds(ping.sum div ping.count).float / 1_000_000.0

proc exchangeEcho(ping: Pinger) {.async.} =
  ## Send one echo request and wait for the matching reply, folding its
  ## round-trip time into `ping`'s statistics. Replies from other hosts,
  ## mismatched replies and short datagrams are logged and skipped.
  ## Raises `OSError` when `sendto` fails.
  ## NOTE(review): there is no receive timeout, so a lost request or reply
  ## leaves this future pending indefinitely and pinging stalls.
  inc ping.msg.fields.sequenceNumber
  # updateChecksum(ping.msg) # Linux does this?
  let
    a = getMonoTime()
    r = sendto(ping.socket.getFd,
      unsafeAddr ping.msg.buffer[0], ping.msg.buffer.len, 0,
      cast[ptr SockAddr](unsafeAddr ping.sad), # sendto takes a generic sockaddr
      ping.sadLen)
  if r == -1'i32:
    raiseOSError(osLastError())
  while true:
    let
      (data, address, _) = await recvFrom(ping.socket, 128)
      b = getMonoTime()
    if address != $ping.rtt.address:
      stderr.writeLine "want ICMP from ", ping.rtt.address, " but received from ", address, " instead"
    elif data.len >= ping.msg.buffer.len:
      let
        period = b - a
        resp = cast[ptr IcmpEcho](unsafeAddr data[0])
      if match(ping.msg.fields, resp.fields):
        update(ping, period)
        return
      else:
        stderr.writeLine "ICMP mismatch"
    else:
      stderr.writeLine "reply data has a bad length ", data.len

proc kick(ping: Pinger) {.gcsafe.} =
  ## Schedule the next echo exchange after `ping.interval`.
  ## On success the RTT assertion is published (or replaced) and the
  ## interval backs off; on failure any previously published RTT is
  ## retracted. Either way the next exchange is scheduled, until the
  ## pinger's socket has been closed.
  if not ping.socket.isClosed:
    addTimer(ping.interval.inMilliseconds.int, oneshot = true) do (fd: AsyncFD) -> bool:
      # The socket may have been closed while this timer was pending.
      if ping.socket.isClosed: return
      let fut = exchangeEcho(ping)
      fut.addCallback do ():
        if fut.failed:
          # Never publish statistics for a failed exchange.
          if not ping.socket.isClosed:
            stderr.writeLine "ICMP exchange with ", ping.rtt.address, " failed: ", fut.error.msg
          if ping.rttHandle != Handle(0):
            ping.facet.run do (turn: var Turn):
              retract(turn, ping.rttHandle)
              reset ping.rttHandle
        else:
          ping.facet.run do (turn: var Turn):
            replace(turn, ping.ds, ping.rttHandle, ping.rtt)
          # Back off: double the interval until it reaches 20 seconds.
          if ping.interval < initDuration(seconds = 20):
            ping.interval = ping.interval * 2
        kick(ping)

type Args {.preservesDictionary.} = object
  dataspace: Cap

runActor("net_mapper") do (root: Cap; turn: var Turn):
  connectStdio(root, turn)
  # Observers of RoundTripTime, capturing the requested address.
  let rttObserver = ?Observe(pattern: !RoundTripTime) ?? {0: grabLit()}
  during(turn, root, ?Args) do (ds: Cap):
    during(turn, ds, rttObserver) do (address: IpAddress):
      var ping: Pinger
      # Only IPv4 targets are pinged; other families are ignored.
      if address.family == IpAddressFamily.IPv4:
        ping = newPinger(address, turn.facet, ds)
        kick(ping)
    do:
      if not ping.isNil: close(ping)