Add multi connection support in secure session manager
diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp
index 0e23ebb..5ea4ed1 100644
--- a/src/controller/CHIPDeviceController.cpp
+++ b/src/controller/CHIPDeviceController.cpp
@@ -51,15 +51,16 @@
ChipDeviceController::ChipDeviceController()
{
- mState = kState_NotInitialized;
- AppState = NULL;
- mConState = kConnectionState_NotConnected;
- mDeviceCon = NULL;
- mCurReqMsg = NULL;
- mOnError = NULL;
- mDeviceAddr = IPAddress::Any;
- mDevicePort = CHIP_PORT;
- mLocalDeviceId = 0;
+ mState = kState_NotInitialized;
+ AppState = NULL;
+ mConState = kConnectionState_NotConnected;
+ mSessionManager = NULL;
+ mCurReqMsg = NULL;
+ mOnError = NULL;
+ mOnNewConnection = NULL;
+ mDeviceAddr = IPAddress::Any;
+ mDevicePort = CHIP_PORT;
+ mLocalDeviceId = 0;
memset(&mOnComplete, 0, sizeof(mOnComplete));
}
@@ -105,10 +106,10 @@
CHIP_ERROR err = CHIP_NO_ERROR;
mState = kState_NotInitialized;
- if (mDeviceCon != NULL)
+ if (mSessionManager != NULL)
{
- delete mDeviceCon;
- mDeviceCon = NULL;
+ delete mSessionManager;
+ mSessionManager = NULL;
}
mSystemLayer->Shutdown();
mInetLayer->Shutdown();
@@ -119,36 +120,41 @@
mConState = kConnectionState_NotConnected;
memset(&mOnComplete, 0, sizeof(mOnComplete));
- mOnError = NULL;
- mMessageNumber = 0;
+ mOnError = NULL;
+ mOnNewConnection = NULL;
+ mMessageNumber = 0;
mRemoteDeviceId.ClearValue();
return err;
}
CHIP_ERROR ChipDeviceController::ConnectDevice(NodeId remoteDeviceId, IPAddress deviceAddr, void * appReqState,
- MessageReceiveHandler onMessageReceived, ErrorHandler onError, uint16_t devicePort)
+ NewConnectionHandler onConnected, MessageReceiveHandler onMessageReceived,
+ ErrorHandler onError, uint16_t devicePort)
{
CHIP_ERROR err = CHIP_NO_ERROR;
- if (mState != kState_Initialized || mDeviceCon != NULL || mConState != kConnectionState_NotConnected)
+ if (mState != kState_Initialized || mSessionManager != NULL || mConState != kConnectionState_NotConnected)
{
return CHIP_ERROR_INCORRECT_STATE;
}
- mRemoteDeviceId = Optional<NodeId>::Value(remoteDeviceId);
- mDeviceAddr = deviceAddr;
- mDevicePort = devicePort;
- mAppReqState = appReqState;
- mDeviceCon = new SecureSessionMgr();
+ mRemoteDeviceId = Optional<NodeId>::Value(remoteDeviceId);
+ mDeviceAddr = deviceAddr;
+ mDevicePort = devicePort;
+ mAppReqState = appReqState;
+ mOnNewConnection = onConnected;
- err = mDeviceCon->Init(mLocalDeviceId, mInetLayer, Transport::UdpListenParameters().SetAddressType(deviceAddr.Type()));
+ mSessionManager = new SecureSessionMgr();
+
+ err = mSessionManager->Init(mLocalDeviceId, mInetLayer, Transport::UdpListenParameters().SetAddressType(deviceAddr.Type()));
SuccessOrExit(err);
- err = mDeviceCon->Connect(remoteDeviceId, Transport::PeerAddress::UDP(deviceAddr, devicePort));
- SuccessOrExit(err);
+ mSessionManager->SetMessageReceiveHandler(OnReceiveMessage, this);
+ mSessionManager->SetNewConnectionHandler(OnNewConnection, this);
- mDeviceCon->SetMessageReceiveHandler(OnReceiveMessage, this);
+ err = mSessionManager->Connect(remoteDeviceId, Transport::PeerAddress::UDP(deviceAddr, devicePort));
+ SuccessOrExit(err);
mOnComplete.Response = onMessageReceived;
mOnError = onError;
@@ -160,26 +166,28 @@
if (err != CHIP_NO_ERROR)
{
- if (mDeviceCon != NULL)
+ if (mSessionManager != NULL)
{
- delete mDeviceCon;
- mDeviceCon = NULL;
+ delete mSessionManager;
+ mSessionManager = NULL;
}
}
return err;
}
-CHIP_ERROR ChipDeviceController::ManualKeyExchange(const unsigned char * remote_public_key, const size_t public_key_length,
- const unsigned char * local_private_key, const size_t private_key_length)
+CHIP_ERROR ChipDeviceController::ManualKeyExchange(Transport::PeerConnectionState * state, const unsigned char * remote_public_key,
+ const size_t public_key_length, const unsigned char * local_private_key,
+ const size_t private_key_length)
{
CHIP_ERROR err = CHIP_NO_ERROR;
- if (!IsConnected() || mDeviceCon == NULL)
+ if (!IsConnected() || mSessionManager == NULL)
{
return CHIP_ERROR_INCORRECT_STATE;
}
- err = mDeviceCon->ManualKeyExchange(remote_public_key, public_key_length, local_private_key, private_key_length);
+ err = state->GetSecureSession().TemporaryManualKeyExchange(remote_public_key, public_key_length, local_private_key,
+ private_key_length);
SuccessOrExit(err);
mConState = kConnectionState_SecureConnected;
@@ -220,9 +228,9 @@
return CHIP_ERROR_INCORRECT_STATE;
}
- delete mDeviceCon;
- mDeviceCon = NULL;
- mConState = kConnectionState_NotConnected;
+ delete mSessionManager;
+ mSessionManager = NULL;
+ mConState = kConnectionState_NotConnected;
return err;
};
@@ -235,7 +243,7 @@
mAppReqState = appReqState;
VerifyOrExit(IsSecurelyConnected(), err = CHIP_ERROR_INCORRECT_STATE);
- err = mDeviceCon->SendMessage(mRemoteDeviceId.Value(), buffer);
+ err = mSessionManager->SendMessage(mRemoteDeviceId.Value(), buffer);
exit:
return err;
@@ -322,7 +330,15 @@
}
}
-void ChipDeviceController::OnReceiveMessage(const MessageHeader & header, const IPPacketInfo & pktInfo,
+void ChipDeviceController::OnNewConnection(Transport::PeerConnectionState * state, ChipDeviceController * mgr)
+{
+ if (mgr->mOnNewConnection)
+ {
+ mgr->mOnNewConnection(mgr, state, mgr->mAppReqState);
+ }
+}
+
+void ChipDeviceController::OnReceiveMessage(const MessageHeader & header, Transport::PeerConnectionState * state,
System::PacketBuffer * msgBuf, ChipDeviceController * mgr)
{
if (header.GetSourceNodeId().HasValue())
@@ -339,7 +355,7 @@
}
if (mgr->IsSecurelyConnected() && mgr->mOnComplete.Response != NULL)
{
- mgr->mOnComplete.Response(mgr, mgr->mAppReqState, msgBuf, &pktInfo);
+ mgr->mOnComplete.Response(mgr, mgr->mAppReqState, msgBuf);
}
}
diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h
index 510d79f..ab8ef9b 100644
--- a/src/controller/CHIPDeviceController.h
+++ b/src/controller/CHIPDeviceController.h
@@ -41,11 +41,12 @@
class ChipDeviceController;
extern "C" {
+typedef void (*NewConnectionHandler)(ChipDeviceController * deviceController, Transport::PeerConnectionState * state,
+ void * appReqState);
typedef void (*CompleteHandler)(ChipDeviceController * deviceController, void * appReqState);
typedef void (*ErrorHandler)(ChipDeviceController * deviceController, void * appReqState, CHIP_ERROR err,
const IPPacketInfo * pktInfo);
-typedef void (*MessageReceiveHandler)(ChipDeviceController * deviceController, void * appReqState, System::PacketBuffer * payload,
- const IPPacketInfo * pktInfo);
+typedef void (*MessageReceiveHandler)(ChipDeviceController * deviceController, void * appReqState, System::PacketBuffer * payload);
};
class DLL_EXPORT ChipDeviceController
@@ -66,12 +67,13 @@
* @param[in] remoteDeviceId The remote device Id.
* @param[in] deviceAddr The IPAddress of the requested Device
* @param[in] appReqState Application specific context to be passed back when a message is received or on error
+ * @param[in] onConnected Callback for when the connection is established
* @param[in] onMessageReceived Callback for when a message is received
* @param[in] onError Callback for when an error occurs
* @param[in] devicePort [Optional] The CHIP Device's port, defaults to CHIP_PORT
* @return CHIP_ERROR The connection status
*/
- CHIP_ERROR ConnectDevice(NodeId remoteDeviceId, IPAddress deviceAddr, void * appReqState,
+ CHIP_ERROR ConnectDevice(NodeId remoteDeviceId, IPAddress deviceAddr, void * appReqState, NewConnectionHandler onConnected,
MessageReceiveHandler onMessageReceived, ErrorHandler onError, uint16_t devicePort = CHIP_PORT);
/**
@@ -80,14 +82,16 @@
* until we have automatic key exchange in place. The function is useful only for
* example applications for now. It will eventually be removed.
*
+ * @param state Peer connection for which to establish the key
* @param remote_public_key A pointer to peer's public key
* @param public_key_length Length of remote_public_key
* @param local_private_key A pointer to local private key
* @param private_key_length Length of local_private_key
* @return CHIP_ERROR The result of key derivation
*/
- CHIP_ERROR ManualKeyExchange(const unsigned char * remote_public_key, const size_t public_key_length,
- const unsigned char * local_private_key, const size_t private_key_length);
+ CHIP_ERROR ManualKeyExchange(Transport::PeerConnectionState * state, const unsigned char * remote_public_key,
+ const size_t public_key_length, const unsigned char * local_private_key,
+ const size_t private_key_length);
/**
* @brief
@@ -168,7 +172,11 @@
System::Layer * mSystemLayer;
Inet::InetLayer * mInetLayer;
- SecureSessionMgr * mDeviceCon;
+
+ // TODO: CHIPDeviceController assumes a single device connection, where as
+ // session manager handles multiple connections. Need to finalize design on this
+ // as otherwise a single connection may end up processing data from other peers.
+ SecureSessionMgr * mSessionManager;
ConnectionState mConState;
void * mAppReqState;
@@ -180,6 +188,7 @@
} mOnComplete;
ErrorHandler mOnError;
+ NewConnectionHandler mOnNewConnection;
System::PacketBuffer * mCurReqMsg;
NodeId mLocalDeviceId;
@@ -191,8 +200,10 @@
void ClearRequestState();
void ClearOpState();
- static void OnReceiveMessage(const MessageHeader & header, const IPPacketInfo & pktInfo, System::PacketBuffer * msgBuf,
- ChipDeviceController * controller);
+ static void OnNewConnection(Transport::PeerConnectionState * state, ChipDeviceController * controller);
+
+ static void OnReceiveMessage(const MessageHeader & header, Transport::PeerConnectionState * state,
+ System::PacketBuffer * msgBuf, ChipDeviceController * controller);
};
} // namespace DeviceController
diff --git a/src/transport/SecureSession.cpp b/src/transport/SecureSession.cpp
index ca776f0..21053d5 100644
--- a/src/transport/SecureSession.cpp
+++ b/src/transport/SecureSession.cpp
@@ -31,6 +31,12 @@
namespace chip {
+namespace {
+
+const char * kManualKeyExchangeChannelInfo = "Manual Key Exchanged Channel";
+
+} // namespace
+
using namespace Crypto;
SecureSession::SecureSession() : mKeyAvailable(false), mNextIV(0) {}
@@ -116,4 +122,18 @@
return error;
}
+CHIP_ERROR SecureSession::TemporaryManualKeyExchange(const unsigned char * remote_public_key, const size_t public_key_length,
+ const unsigned char * local_private_key, const size_t private_key_length)
+{
+ CHIP_ERROR err = CHIP_NO_ERROR;
+ size_t info_len = strlen(kManualKeyExchangeChannelInfo);
+
+ err = Init(remote_public_key, public_key_length, local_private_key, private_key_length, NULL, 0,
+ (const unsigned char *) kManualKeyExchangeChannelInfo, info_len);
+ SuccessOrExit(err);
+
+exit:
+ return err;
+}
+
} // namespace chip
diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h
index 609d6e6..65ee9d7 100644
--- a/src/transport/SecureSession.h
+++ b/src/transport/SecureSession.h
@@ -93,6 +93,22 @@
*/
void Reset(void);
+ /**
+ * @brief
+ * The keypair for the secure channel. This is a utility function that will be used
+ * until we have automatic key exchange in place. The function is useful only for
+ * example applications for now. It will eventually be removed.
+ *
+ * @param remote_public_key A pointer to peer's public key
+ * @param public_key_length Length of remote_public_key
+ * @param local_private_key A pointer to local private key
+ * @param private_key_length Length of local_private_key
+ * @return CHIP_ERROR The result of key derivation
+ */
+ [[deprecated("Available until actual key exchange is implemented")]] CHIP_ERROR
+ TemporaryManualKeyExchange(const unsigned char * remote_public_key, const size_t public_key_length,
+ const unsigned char * local_private_key, const size_t private_key_length);
+
private:
static constexpr size_t kAES_CCM128_Key_Length = 16;
diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp
index 410ebf6..35441ca 100644
--- a/src/transport/SecureSessionMgr.cpp
+++ b/src/transport/SecureSessionMgr.cpp
@@ -33,13 +33,15 @@
namespace chip {
+using Transport::PeerAddress;
+using Transport::PeerConnectionState;
+
// Maximum length of application data that can be encrypted as one block.
// The limit is derived from IPv6 MTU (1280 bytes) - expected header overheads.
// This limit would need additional reviews once we have formalized Secure Transport header.
-static const size_t kMax_SecureSDU_Length = 1024;
-static const char * kManualKeyExchangeChannelInfo = "Manual Key Exchanged Channel";
+static const size_t kMax_SecureSDU_Length = 1024;
-SecureSessionMgr::SecureSessionMgr() : mConnectionState(Transport::PeerAddress::Uninitialized()), mState(State::kNotReady)
+SecureSessionMgr::SecureSessionMgr() : mState(State::kNotReady)
{
OnMessageReceived = NULL;
}
@@ -52,7 +54,7 @@
err = mTransport.Init(inet, listenParams);
SuccessOrExit(err);
- mTransport.SetMessageReceiveHandler(HandleDataReceived, this);
+ mTransport.SetMessageReceiveHandler(HandleUdpDataReceived, this);
mState = State::kInitialized;
mLocalNodeId = localNodeId;
@@ -62,30 +64,24 @@
CHIP_ERROR SecureSessionMgr::Connect(NodeId peerNodeId, const Transport::PeerAddress & peerAddress)
{
- CHIP_ERROR err = CHIP_NO_ERROR;
+ CHIP_ERROR err = CHIP_NO_ERROR;
+ PeerConnectionState * state = nullptr;
VerifyOrExit(mState == State::kInitialized, err = CHIP_ERROR_INCORRECT_STATE);
- mConnectionState.SetPeerNodeId(peerNodeId);
- mConnectionState.SetPeerAddress(peerAddress);
-
- mState = State::kConnected;
-
-exit:
- return err;
-}
-
-CHIP_ERROR SecureSessionMgr::ManualKeyExchange(const unsigned char * remote_public_key, const size_t public_key_length,
- const unsigned char * local_private_key, const size_t private_key_length)
-{
- CHIP_ERROR err = CHIP_NO_ERROR;
- size_t info_len = strlen(kManualKeyExchangeChannelInfo);
-
- VerifyOrExit(mState == State::kConnected || mState == State::kInitialized, err = CHIP_ERROR_INCORRECT_STATE);
- err = mConnectionState.GetSecureSession().Init(remote_public_key, public_key_length, local_private_key, private_key_length,
- NULL, 0, (const unsigned char *) kManualKeyExchangeChannelInfo, info_len);
+ err = mPeerConnections.CreateNewPeerConnectionState(peerAddress, &state);
SuccessOrExit(err);
- mState = State::kSecureConnected;
+
+ state->SetPeerNodeId(peerNodeId);
+
+ // TODO:
+ // - on new connection for other transports may not be inline.
+ // - determine if logic is required to prevent multiple connections to the
+ // same node or if some connection reuse is required.
+ if (OnNewConnection)
+ {
+ OnNewConnection(state, mNewConnectionArgument);
+ }
exit:
return err;
@@ -93,40 +89,42 @@
CHIP_ERROR SecureSessionMgr::SendMessage(NodeId peerNodeId, System::PacketBuffer * msgBuf)
{
- CHIP_ERROR err = CHIP_NO_ERROR;
+ CHIP_ERROR err = CHIP_NO_ERROR;
+ PeerConnectionState * state = nullptr;
- VerifyOrExit(StateAllowsSend(), err = CHIP_ERROR_INCORRECT_STATE);
+ VerifyOrExit(mState == State::kInitialized, err = CHIP_ERROR_INCORRECT_STATE);
VerifyOrExit(msgBuf != NULL, err = CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrExit(msgBuf->Next() == NULL, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH);
VerifyOrExit(msgBuf->TotalLength() < kMax_SecureSDU_Length, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH);
+ // Find an active connection to the specified peer node
+ VerifyOrExit(mPeerConnections.FindPeerConnectionState(peerNodeId, &state), err = CHIP_ERROR_INVALID_DESTINATION_NODE_ID);
+
{
uint8_t * data = msgBuf->Start();
MessageHeader header;
- err = mConnectionState.GetSecureSession().Encrypt(data, msgBuf->TotalLength(), data, header);
+ err = state->GetSecureSession().Encrypt(data, msgBuf->TotalLength(), data, header);
SuccessOrExit(err);
- ChipLogProgress(Inet, "Secure transport transmitting msg %u after encryption", mConnectionState.GetSendMessageIndex());
+ ChipLogProgress(Inet, "Secure transport transmitting msg %u after encryption", state->GetSendMessageIndex());
header
.SetSourceNodeId(mLocalNodeId) //
.SetDestinationNodeId(peerNodeId) //
- .SetMessageId(mConnectionState.GetSendMessageIndex());
+ .SetMessageId(state->GetSendMessageIndex());
- err = mTransport.SendMessage(header, mConnectionState.GetPeerAddress(), msgBuf);
+ err = mTransport.SendMessage(header, state->GetPeerAddress(), msgBuf);
msgBuf = NULL;
}
SuccessOrExit(err);
-
- mConnectionState.IncrementSendMessageIndex();
+ state->IncrementSendMessageIndex();
exit:
if (msgBuf != NULL)
{
- ChipLogProgress(Inet, "Secure transport failed to encrypt msg %u: %s", mConnectionState.GetSendMessageIndex(),
- ErrorStr(err));
+ ChipLogProgress(Inet, "Secure transport failed to encrypt msg %u: %s", state->GetSendMessageIndex(), ErrorStr(err));
PacketBuffer::Free(msgBuf);
msgBuf = NULL;
}
@@ -134,21 +132,47 @@
return err;
}
-void SecureSessionMgr::HandleDataReceived(const MessageHeader & header, const IPPacketInfo & pktInfo, System::PacketBuffer * msg,
- SecureSessionMgr * connection)
+CHIP_ERROR SecureSessionMgr::AllocateNewConnection(const MessageHeader & header, const PeerAddress & address,
+ Transport::PeerConnectionState ** state)
{
- CHIP_ERROR err = CHIP_ERROR_INVALID_ARGUMENT;
+ CHIP_ERROR err = CHIP_NO_ERROR;
- System::PacketBuffer * origMsg = nullptr;
+ err = mPeerConnections.CreateNewPeerConnectionState(address, state);
+ SuccessOrExit(err);
- // TODO: actual key exchange should happen here
- if (!connection->StateAllowsReceive())
+ if (header.GetSourceNodeId().HasValue())
{
- if (connection->OnNewConnection)
+ (*state)->SetPeerNodeId(header.GetSourceNodeId().Value());
+ }
+
+ if (OnNewConnection)
+ {
+ OnNewConnection(*state, mNewConnectionArgument);
+ }
+
+exit:
+ return err;
+}
+void SecureSessionMgr::HandleUdpDataReceived(const MessageHeader & header, const IPPacketInfo & pktInfo, System::PacketBuffer * msg,
+ SecureSessionMgr * connection)
+
+{
+ CHIP_ERROR err = CHIP_NO_ERROR;
+ System::PacketBuffer * origMsg = nullptr;
+ PeerConnectionState * state = nullptr;
+
+ VerifyOrExit(msg != nullptr, ChipLogError(Inet, "Secure transport received NULL packet, discarding"));
+
+ {
+ PeerAddress peerAddress = PeerAddress::UDP(pktInfo.SrcAddress, pktInfo.SrcPort);
+
+ if (!connection->mPeerConnections.FindPeerConnectionState(peerAddress, &state))
{
- connection->OnNewConnection(header, pktInfo, connection->mNewConnectionArgument);
+ ChipLogProgress(Inet, "New peer connection received.");
+
+ err = connection->AllocateNewConnection(header, peerAddress, &state);
+ SuccessOrExit(err);
}
- err = CHIP_NO_ERROR;
}
if (!connection->StateAllowsReceive())
@@ -156,8 +180,6 @@
ExitNow(ChipLogProgress(Inet, "Secure transport failed: state does not allow receive"));
}
- VerifyOrExit(msg != nullptr, ChipLogError(Inet, "Secure transport received NULL packet, discarding"));
-
// TODO this is where messages should be decoded
{
uint8_t * data = msg->Start();
@@ -172,12 +194,12 @@
#endif
plainText = msg->Start();
- err = connection->mConnectionState.GetSecureSession().Decrypt(data, len, plainText, header);
+ err = state->GetSecureSession().Decrypt(data, len, plainText, header);
VerifyOrExit(err == CHIP_NO_ERROR, ChipLogProgress(Inet, "Secure transport failed to decrypt msg: err %d", err));
if (connection->OnMessageReceived)
{
- connection->OnMessageReceived(header, pktInfo, msg, connection->mMessageReceivedArgument);
+ connection->OnMessageReceived(header, state, msg, connection->mMessageReceivedArgument);
msg = nullptr;
}
}
diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h
index e54488d..11d9e35 100644
--- a/src/transport/SecureSessionMgr.h
+++ b/src/transport/SecureSessionMgr.h
@@ -32,7 +32,7 @@
#include <core/ReferenceCounted.h>
#include <inet/IPAddress.h>
#include <inet/IPEndPointBasis.h>
-#include <transport/PeerConnectionState.h>
+#include <transport/PeerConnections.h>
#include <transport/SecureSession.h>
#include <transport/UDP.h>
@@ -48,10 +48,8 @@
*/
enum class State
{
- kNotReady, /**< State before initialization. */
- kInitialized, /**< State when the object is ready to connect. */
- kConnected, /**< State when the remte peer is connected. */
- kSecureConnected, /**< State when the security of the connection has been established. */
+ kNotReady, /**< State before initialization. */
+ kInitialized, /**< State when the object is ready connect to other peers. */
};
/**
@@ -68,21 +66,6 @@
CHIP_ERROR Init(NodeId localNodeId, Inet::InetLayer * inet, const Transport::UdpListenParameters & listenParams);
/**
- * @brief
- * The keypair for the secure channel. This is a utility function that will be used
- * until we have automatic key exchange in place. The function is useful only for
- * example applications for now. It will eventually be removed.
- *
- * @param remote_public_key A pointer to peer's public key
- * @param public_key_length Length of remote_public_key
- * @param local_private_key A pointer to local private key
- * @param private_key_length Length of local_private_key
- * @return CHIP_ERROR The result of key derivation
- */
- CHIP_ERROR ManualKeyExchange(const unsigned char * remote_public_key, const size_t public_key_length,
- const unsigned char * local_private_key, const size_t private_key_length);
-
- /**
* Establishes a connection to the given peer node.
*
* A connection needs to be established before SendMessage can be called.
@@ -110,7 +93,8 @@
*
*/
template <class T>
- void SetMessageReceiveHandler(void (*handler)(const MessageHeader &, const Inet::IPPacketInfo &, System::PacketBuffer *, T *),
+ void SetMessageReceiveHandler(void (*handler)(const MessageHeader &, Transport::PeerConnectionState *, System::PacketBuffer *,
+ T *),
T * param)
{
mMessageReceivedArgument = param;
@@ -136,36 +120,19 @@
*
*/
template <class T>
- void SetNewConnectionHandler(void (*handler)(const MessageHeader &, const Inet::IPPacketInfo &, T *), T * param)
+ void SetNewConnectionHandler(void (*handler)(Transport::PeerConnectionState *, T *), T * param)
{
mNewConnectionArgument = param;
OnNewConnection = reinterpret_cast<NewConnectionHandler>(handler);
}
- /**
- * TEMPORARY method for current connection until multi-session handling support
- * is added to this class.
- */
- NodeId GetPeerNodeId() const { return mConnectionState.GetPeerNodeId(); };
-
private:
- // TODO: settings below are single-transport and single-peer. Long term
- // intent for the class is to do session management and there will be
- // multiple transports and multiple peers supported.
- // Expected changes:
- // - Transports including UDP, TCP, BLE as needed
- // - Multiple peers with separate states (encryption settings and handshake state)
- // - Global state is only a ready/notready and connection state is deferred into
- // individual connections.
-
+ // TODO: add support for multiple transports (TCP, BLE to be added)
Transport::UDP mTransport;
- Transport::PeerConnectionState mConnectionState;
- NodeId mLocalNodeId;
- State mState;
-
- bool StateAllowsSend(void) const { return mState != State::kNotReady; }
- bool StateAllowsReceive(void) const { return mState == State::kSecureConnected; }
+ NodeId mLocalNodeId; //< Id of the current node
+ Transport::PeerConnections mPeerConnections; //< Active connections to other peers
+ State mState; //< Initialization state of the object
/**
* This function is the application callback that is invoked when a message is received over a
@@ -173,7 +140,7 @@
*
* @param[in] msgBuf A pointer to the PacketBuffer object holding the message.
*/
- typedef void (*MessageReceiveHandler)(const MessageHeader & header, const Inet::IPPacketInfo & source,
+ typedef void (*MessageReceiveHandler)(const MessageHeader & header, Transport::PeerConnectionState * state,
System::PacketBuffer * msgBuf, void * param);
MessageReceiveHandler OnMessageReceived = nullptr; ///< Callback on message receiving
@@ -183,13 +150,28 @@
ReceiveErrorHandler OnReceiveError = nullptr; ///< Callback on error in message receiving
- typedef void (*NewConnectionHandler)(const MessageHeader & header, const Inet::IPPacketInfo & source, void * param);
+ typedef void (*NewConnectionHandler)(Transport::PeerConnectionState * state, void * param);
NewConnectionHandler OnNewConnection = nullptr; ///< Callback for new connection received
void * mNewConnectionArgument = nullptr; ///< Argument for callback
- static void HandleDataReceived(const MessageHeader & header, const Inet::IPPacketInfo & source, System::PacketBuffer * msgBuf,
- SecureSessionMgr * transport);
+ /**
+ * Allocates a new connection for the given source.
+ *
+ * @param header The header that was received when the connection was established
+ * @param address Where the connection originated from
+ * @param state [out] the newly allocated connection state for the connection
+ */
+ CHIP_ERROR AllocateNewConnection(const MessageHeader & header, const Transport::PeerAddress & address,
+ Transport::PeerConnectionState ** state);
+
+ /**
+ * Handle UDP data receiving. Each transport has separate data receiving as active sessions
+ * follow data receiving channels.
+ *
+ */
+ static void HandleUdpDataReceived(const MessageHeader & header, const Inet::IPPacketInfo & source,
+ System::PacketBuffer * msgBuf, SecureSessionMgr * transport);
};
} // namespace chip
diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp
index 88b449b..0156b90 100644
--- a/src/transport/tests/TestSecureSessionMgr.cpp
+++ b/src/transport/tests/TestSecureSessionMgr.cpp
@@ -62,11 +62,12 @@
int ReceiveHandlerCallCount = 0;
-static void MessageReceiveHandler(const MessageHeader & header, const Inet::IPPacketInfo & source, System::PacketBuffer * msgBuf,
- nlTestSuite * inSuite)
+static void MessageReceiveHandler(const MessageHeader & header, Transport::PeerConnectionState * state,
+ System::PacketBuffer * msgBuf, nlTestSuite * inSuite)
{
NL_TEST_ASSERT(inSuite, header.GetSourceNodeId() == Optional<NodeId>::Value(kSourceNodeId));
NL_TEST_ASSERT(inSuite, header.GetDestinationNodeId() == Optional<NodeId>::Value(kDestinationNodeId));
+ NL_TEST_ASSERT(inSuite, state->GetPeerNodeId() == kDestinationNodeId);
size_t data_len = msgBuf->DataLength();
@@ -74,7 +75,19 @@
NL_TEST_ASSERT(inSuite, compare == 0);
ReceiveHandlerCallCount++;
-};
+}
+
+int NewConnectionHandlerCallCount = 0;
+static void NewConnectionHandler(Transport::PeerConnectionState * state, nlTestSuite * inSuite)
+{
+ CHIP_ERROR err;
+
+ NewConnectionHandlerCallCount++;
+
+ err = state->GetSecureSession().TemporaryManualKeyExchange(remote_public_key, sizeof(remote_public_key), local_private_key,
+ sizeof(local_private_key));
+ NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
+}
static void DriveIO(TestContext & ctx)
{
@@ -162,14 +175,14 @@
err = conn.Init(kSourceNodeId, &ctx.mInetLayer, Transport::UdpListenParameters().SetAddressType(addr.Type()));
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
- err = conn.Connect(kDestinationNodeId, Transport::PeerAddress::UDP(addr));
- NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
-
- err = conn.ManualKeyExchange(remote_public_key, sizeof(remote_public_key), local_private_key, sizeof(local_private_key));
- NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
-
+ conn.SetNewConnectionHandler(NewConnectionHandler, inSuite);
conn.SetMessageReceiveHandler(MessageReceiveHandler, inSuite);
+ NewConnectionHandlerCallCount = 0;
+ err = conn.Connect(kDestinationNodeId, Transport::PeerAddress::UDP(addr));
+ NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
+ NL_TEST_ASSERT(inSuite, NewConnectionHandlerCallCount == 1);
+
// Should be able to send a message to itself by just calling send.
ReceiveHandlerCallCount = 0;
err = conn.SendMessage(kDestinationNodeId, buffer);