diff --git a/packages/network-transport-inmemory/src/Network/Transport/InMemory/Internal.hs b/packages/network-transport-inmemory/src/Network/Transport/InMemory/Internal.hs index 1000fd1e..d5762c95 100644 --- a/packages/network-transport-inmemory/src/Network/Transport/InMemory/Internal.hs +++ b/packages/network-transport-inmemory/src/Network/Transport/InMemory/Internal.hs @@ -156,24 +156,34 @@ apiNewEndPoint state = handle (return . Left) $ atomically $ do TransportError ResolveMulticastGroupUnsupported "Multicast not supported" apiCloseEndPoint :: TVar TransportState -> EndPointAddress -> IO () -apiCloseEndPoint state addr = atomically $ whenValidTransportState state $ \vst -> - forM_ (vst ^. localEndPointAt addr) $ \lep -> do - old <- swapTVar (localEndPointState lep) LocalEndPointClosed - case old of - LocalEndPointClosed -> return () - LocalEndPointValid lepvst -> do - forM_ (Map.elems (lepvst ^. connections)) $ \lconn -> do - st <- swapTVar (localConnectionState lconn) LocalConnectionClosed - case st of - LocalConnectionClosed -> return () - LocalConnectionFailed -> return () - _ -> forM_ (vst ^. localEndPointAt (localConnectionRemoteAddress lconn)) $ \thep -> - whenValidLocalEndPointState thep $ \_ -> do - writeTChan (localEndPointChannel thep) - (ConnectionClosed (localConnectionId lconn)) - writeTChan (localEndPointChannel lep) EndPointClosed - writeTVar (localEndPointState lep) LocalEndPointClosed - writeTVar state (TransportValid $ (localEndPoints ^: Map.delete addr) vst) +apiCloseEndPoint state addr = atomically $ whenValidTransportState state $ \vst -> do + + forM_ (Map.toList $ _localEndPoints vst) $ + \(theirAddr, lep) -> do + + if theirAddr == addr + then do + old <- swapTVar (localEndPointState lep) LocalEndPointClosed + case old of + LocalEndPointClosed -> return () + LocalEndPointValid lepvst -> do + forM_ (Map.elems (lepvst ^. connections)) $ \lconn -> do + st <- swapTVar (localConnectionState lconn) LocalConnectionClosed + case st of + LocalConnectionClosed -> return () + LocalConnectionFailed -> return () + _ -> do + forM_ (vst ^. localEndPointAt (localConnectionRemoteAddress lconn)) $ \thep -> + whenValidLocalEndPointState thep $ \_ -> do + writeTChan (localEndPointChannel thep) + (ConnectionClosed (localConnectionId lconn)) + writeTChan (localEndPointChannel lep) EndPointClosed + writeTVar (localEndPointState lep) LocalEndPointClosed + + else do + apiBreakConnection state addr theirAddr "remote endpoint disconnected" + + writeTVar state (TransportValid $ (localEndPoints ^: Map.delete addr) vst) -- | Tear down functions that should be called in case if conncetion fails. apiBreakConnection :: TVar TransportState