Repair noise session introduction

This commit is contained in:
Tony Garnock-Jones 2024-03-28 16:32:46 +01:00
parent 5090625f47
commit 2ed2b38edc
1 changed files with 53 additions and 34 deletions

View File

@ -267,18 +267,18 @@ fn await_bind_noise(
ds: &mut Arc<Cap>, ds: &mut Arc<Cap>,
t: &mut Activation, t: &mut Activation,
service_selector: AnyValue, service_selector: AnyValue,
initiator_session: Arc<Cap>, observer: Arc<Cap>,
direct_resolution_facet: FacetId, direct_resolution_facet: FacetId,
) -> ActorResult { ) -> ActorResult {
let handler = syndicate::entity(()) let handler = syndicate::entity(())
.on_asserted_facet(move |_state, t, a: AnyValue| { .on_asserted_facet(move |_state, t, a: AnyValue| {
t.stop_facet(direct_resolution_facet); t.stop_facet(direct_resolution_facet);
let initiator_session = Arc::clone(&initiator_session); let observer = Arc::clone(&observer);
t.spawn_link(None, move |t| { t.spawn_link(None, move |t| {
let bindings = a.value().to_sequence()?; let bindings = a.value().to_sequence()?;
let spec = validate_noise_spec(language().parse(&bindings[0])?)?; let spec = validate_noise_spec(language().parse(&bindings[0])?)?;
let service = bindings[1].value().to_embedded()?; let service = bindings[1].value().to_embedded()?;
run_noise_responder(t, spec, initiator_session, Arc::clone(service)) run_noise_responder(t, spec, observer, Arc::clone(service))
}); });
Ok(()) Ok(())
}) })
@ -293,25 +293,50 @@ fn await_bind_noise(
Ok(()) Ok(())
} }
struct ResponderDetails { type HandshakeState = noise_protocol::HandshakeState<X25519, ChaCha20Poly1305, Blake2s>;
initiator_session: Arc<Cap>,
service: Arc<Cap>,
}
struct ResponderTransport {
relay_input: Arc<Mutex<Option<TunnelRelay>>>,
c_recv: CipherState<ChaCha20Poly1305>
}
enum ResponderState { enum ResponderState {
Handshake(ResponderDetails, noise_protocol::HandshakeState<X25519, ChaCha20Poly1305, Blake2s>), Invalid, // used during state transitions
Transport(ResponderTransport), Introduction {
service: Arc<Cap>,
hs: HandshakeState,
},
Handshake {
initiator_session: Arc<Cap>,
service: Arc<Cap>,
hs: HandshakeState,
},
Transport {
relay_input: Arc<Mutex<Option<TunnelRelay>>>,
c_recv: CipherState<ChaCha20Poly1305>,
},
} }
impl Entity<noise::Packet> for ResponderState { impl Entity<noise::SessionItem> for ResponderState {
fn message(&mut self, t: &mut Activation, p: noise::Packet) -> ActorResult { fn assert(&mut self, _t: &mut Activation, item: noise::SessionItem, _handle: Handle) -> ActorResult {
let initiator_session = match item {
noise::SessionItem::Initiator(i_box) => i_box.initiator_session,
noise::SessionItem::Packet(_) => Err("Unexpected Packet assertion")?,
};
match std::mem::replace(self, ResponderState::Invalid) {
ResponderState::Introduction { service, hs } => {
*self = ResponderState::Handshake { initiator_session, service, hs };
Ok(())
}
_ =>
Err("Received second Initiator")?,
}
}
fn message(&mut self, t: &mut Activation, item: noise::SessionItem) -> ActorResult {
let p = match item {
noise::SessionItem::Initiator(_) => Err("Unexpected Initiator message")?,
noise::SessionItem::Packet(p_box) => *p_box,
};
match self { match self {
ResponderState::Handshake(details, hs) => match p { ResponderState::Invalid | ResponderState::Introduction { .. } =>
Err("Received Packet in invalid ResponderState")?,
ResponderState::Handshake { initiator_session, service, hs } => match p {
noise::Packet::Complete(bs) => { noise::Packet::Complete(bs) => {
if bs.len() < hs.get_next_message_overhead() { if bs.len() < hs.get_next_message_overhead() {
Err("Invalid handshake message for pattern")?; Err("Invalid handshake message for pattern")?;
@ -322,14 +347,13 @@ impl Entity<noise::Packet> for ResponderState {
hs.read_message(&bs, &mut [])?; hs.read_message(&bs, &mut [])?;
let mut reply = vec![0u8; hs.get_next_message_overhead()]; let mut reply = vec![0u8; hs.get_next_message_overhead()];
hs.write_message(&[], &mut reply[..])?; hs.write_message(&[], &mut reply[..])?;
details.initiator_session.message(t, language(), &noise::Packet::Complete(reply.into())); initiator_session.message(t, language(), &noise::Packet::Complete(reply.into()));
if hs.completed() { if hs.completed() {
let (c_recv, mut c_send) = hs.get_ciphers(); let (c_recv, mut c_send) = hs.get_ciphers();
let (_, relay_input, mut relay_output) = let (_, relay_input, mut relay_output) =
TunnelRelay::_run(t, Some(Arc::clone(&details.service)), None, false); TunnelRelay::_run(t, Some(Arc::clone(service)), None, false);
let trace_collector = t.trace_collector(); let trace_collector = t.trace_collector();
let transport = ResponderTransport { relay_input, c_recv }; let initiator_session = Arc::clone(initiator_session);
let initiator_session = Arc::clone(&details.initiator_session);
let relay_output_name = Some(AnyValue::symbol("relay_output")); let relay_output_name = Some(AnyValue::symbol("relay_output"));
let transport_facet = t.facet_ref(); let transport_facet = t.facet_ref();
t.linked_task(relay_output_name.clone(), async move { t.linked_task(relay_output_name.clone(), async move {
@ -360,25 +384,25 @@ impl Entity<noise::Packet> for ResponderState {
} }
Ok(LinkedTaskTermination::Normal) Ok(LinkedTaskTermination::Normal)
}); });
*self = ResponderState::Transport(transport); *self = ResponderState::Transport { relay_input, c_recv };
} }
} }
_ => Err("Fragmented handshake is not allowed")?, _ => Err("Fragmented handshake is not allowed")?,
}, },
ResponderState::Transport(transport) => { ResponderState::Transport { relay_input, c_recv } => {
let bs = match p { let bs = match p {
noise::Packet::Complete(bs) => noise::Packet::Complete(bs) =>
transport.c_recv.decrypt_vec(&bs[..]).map_err(|_| "Cannot decrypt packet")?, c_recv.decrypt_vec(&bs[..]).map_err(|_| "Cannot decrypt packet")?,
noise::Packet::Fragmented(pieces) => { noise::Packet::Fragmented(pieces) => {
let mut result = Vec::with_capacity(1024); let mut result = Vec::with_capacity(1024);
for piece in pieces { for piece in pieces {
result.extend(transport.c_recv.decrypt_vec(&piece[..]) result.extend(c_recv.decrypt_vec(&piece[..])
.map_err(|_| "Cannot decrypt packet fragment")?); .map_err(|_| "Cannot decrypt packet fragment")?);
} }
result result
} }
}; };
let mut g = transport.relay_input.lock(); let mut g = relay_input.lock();
let tr = g.as_mut().expect("initialized"); let tr = g.as_mut().expect("initialized");
tr.handle_inbound_datagram(t, &bs[..])?; tr.handle_inbound_datagram(t, &bs[..])?;
} }
@ -446,7 +470,7 @@ fn lookup_pattern(name: &str) -> Option<HandshakePattern> {
fn run_noise_responder( fn run_noise_responder(
t: &mut Activation, t: &mut Activation,
spec: ValidatedNoiseSpec, spec: ValidatedNoiseSpec,
initiator_session: Arc<Cap>, observer: Arc<Cap>,
service: Arc<Cap>, service: Arc<Cap>,
) -> ActorResult { ) -> ActorResult {
let hs = { let hs = {
@ -469,13 +493,8 @@ fn run_noise_responder(
hs hs
}; };
let details = ResponderDetails {
initiator_session: initiator_session.clone(),
service,
};
let responder_session = let responder_session =
Cap::guard(crate::Language::arc(), t.create(ResponderState::Handshake(details, hs))); Cap::guard(crate::Language::arc(), t.create(ResponderState::Introduction{ service, hs }));
initiator_session.assert(t, language(), &gatekeeper::Resolved::Accepted { responder_session }); observer.assert(t, language(), &gatekeeper::Resolved::Accepted { responder_session });
Ok(()) Ok(())
} }