drivers: modem: ublox-sara-r4: add TLS offload support

Currently it's able to connect to google iot. All other use cases are
untested.

Signed-off-by: Johann Tael <jntael@gmail.com>
diff --git a/drivers/modem/modem_socket.h b/drivers/modem/modem_socket.h
index e0309f1..ab1a06f 100644
--- a/drivers/modem/modem_socket.h
+++ b/drivers/modem/modem_socket.h
@@ -26,7 +26,7 @@
 __net_socket struct modem_socket {
 	sa_family_t family;
 	enum net_sock_type type;
-	enum net_ip_protocol ip_proto;
+	int ip_proto;
 	struct sockaddr src;
 	struct sockaddr dst;
 	int id;
diff --git a/drivers/modem/ublox-sara-r4.c b/drivers/modem/ublox-sara-r4.c
index 32bbf6b..8622cf6 100644
--- a/drivers/modem/ublox-sara-r4.c
+++ b/drivers/modem/ublox-sara-r4.c
@@ -34,6 +34,12 @@
 #define CONFIG_MODEM_UBLOX_SARA_R4_MANUAL_MCCMNO ""
 #endif
 
+
+#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS)
+#include "tls_internal.h"
+#include <net/tls_credentials.h>
+#endif
+
 /* pin settings */
 enum mdm_control_pins {
 	MDM_POWER = 0,
@@ -93,7 +99,7 @@
 #define MDM_IMEI_LENGTH			16
 #define MDM_IMSI_LENGTH			16
 #define MDM_APN_LENGTH			32
-
+#define MDM_MAX_CERT_LENGTH		8192
 #if defined(CONFIG_MODEM_UBLOX_SARA_AUTODETECT_VARIANT)
 #define MDM_VARIANT_UBLOX_R4 4
 #define MDM_VARIANT_UBLOX_U2 2
@@ -437,6 +443,71 @@
 	return mdata.sock_written;
 }
 
+#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS)
+/* send binary data via the +USO[ST/WR] commands */
+static ssize_t send_cert(struct modem_socket *sock,
+			 struct modem_cmd *handler_cmds,
+			 size_t handler_cmds_len,
+			 const char *cert_data, size_t cert_len,
+			 int cert_type)
+{
+	int ret;
+	char *filename = "ca";
+	char send_buf[sizeof("AT+USECMNG=#,#,!####!,####\r\n")];
+
+	/* TODO support other cert types as well */
+	if (cert_type != 0) {
+		return -EINVAL;
+	}
+
+	if (!sock) {
+		return -EINVAL;
+	}
+
+	__ASSERT_NO_MSG(cert_len <= MDM_MAX_CERT_LENGTH);
+
+	snprintk(send_buf, sizeof(send_buf),
+		 "AT+USECMNG=0,%d,\"%s\",%d", cert_type, filename, cert_len);
+
+	k_sem_take(&mdata.cmd_handler_data.sem_tx_lock, K_FOREVER);
+
+	ret = modem_cmd_send_nolock(&mctx.iface, &mctx.cmd_handler,
+				    NULL, 0U, send_buf, NULL, K_NO_WAIT);
+	if (ret < 0) {
+		goto exit;
+	}
+
+	/* set command handlers */
+	ret = modem_cmd_handler_update_cmds(&mdata.cmd_handler_data,
+					    handler_cmds, handler_cmds_len,
+					    true);
+	if (ret < 0) {
+		goto exit;
+	}
+
+	/* slight pause per spec so that @ prompt is received */
+	k_sleep(MDM_PROMPT_CMD_DELAY);
+	mctx.iface.write(&mctx.iface, cert_data, cert_len);
+
+	k_sem_reset(&mdata.sem_response);
+	ret = k_sem_take(&mdata.sem_response, K_MSEC(1000));
+
+	if (ret == 0) {
+		ret = modem_cmd_handler_get_error(&mdata.cmd_handler_data);
+	} else if (ret == -EAGAIN) {
+		ret = -ETIMEDOUT;
+	}
+
+exit:
+	/* unset handler commands and ignore any errors */
+	(void)modem_cmd_handler_update_cmds(&mdata.cmd_handler_data,
+					    NULL, 0U, false);
+	k_sem_give(&mdata.cmd_handler_data.sem_tx_lock);
+
+	return ret;
+}
+#endif
+
 /*
  * Modem Response Command Handlers
  */
@@ -643,6 +714,15 @@
 	return 0;
 }
 
+#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS)
+/* Handler: +USECMNG: 0,<type>[0],<internal_name>[1],<md5_string>[2] */
+MODEM_CMD_DEFINE(on_cmd_cert_write)
+{
+	LOG_DBG("cert md5: %s", log_strdup(argv[2]));
+	return 0;
+}
+#endif
+
 /* Common code for +USOR[D|F]: "<data>" */
 static int on_cmd_sockread_common(int socket_id,
 				  struct modem_cmd_handler_data *data,
@@ -1252,14 +1332,57 @@
 			     &cmd, 1U, buf,
 			     &mdata.sem_response, MDM_CMD_TIMEOUT);
 	if (ret < 0) {
-		LOG_ERR("%s ret:%d", log_strdup(buf), ret);
-		modem_socket_put(&mdata.socket_config, sock->sock_fd);
-		errno = -ret;
-		return -1;
+		goto error;
+	}
+
+	if (sock->ip_proto == IPPROTO_TLS_1_2) {
+		char buf[sizeof("AT+USECPRF=#,#,#######\r")];
+
+		/* Enable socket security */
+		snprintk(buf, sizeof(buf), "AT+USOSEC=%d,1,%d", sock->id, sock->id);
+		ret = modem_cmd_send(&mctx.iface, &mctx.cmd_handler, NULL, 0U, buf,
+				     &mdata.sem_response, MDM_CMD_TIMEOUT);
+		if (ret < 0) {
+			goto error;
+		}
+		/* Reset the security profile */
+		snprintk(buf, sizeof(buf), "AT+USECPRF=%d", sock->id);
+		ret = modem_cmd_send(&mctx.iface, &mctx.cmd_handler, NULL, 0U, buf,
+				     &mdata.sem_response, MDM_CMD_TIMEOUT);
+		if (ret < 0) {
+			goto error;
+		}
+		/* Validate server cert against the CA.  */
+		snprintk(buf, sizeof(buf), "AT+USECPRF=%d,0,1", sock->id);
+		ret = modem_cmd_send(&mctx.iface, &mctx.cmd_handler, NULL, 0U, buf,
+				     &mdata.sem_response, MDM_CMD_TIMEOUT);
+		if (ret < 0) {
+			goto error;
+		}
+		/* Use TLSv1.2 only */
+		snprintk(buf, sizeof(buf), "AT+USECPRF=%d,1,3", sock->id);
+		ret = modem_cmd_send(&mctx.iface, &mctx.cmd_handler, NULL, 0U, buf,
+				     &mdata.sem_response, MDM_CMD_TIMEOUT);
+		if (ret < 0) {
+			goto error;
+		}
+		/* Set root CA filename */
+		snprintk(buf, sizeof(buf), "AT+USECPRF=%d,3,\"ca\"", sock->id);
+		ret = modem_cmd_send(&mctx.iface, &mctx.cmd_handler, NULL, 0U, buf,
+				     &mdata.sem_response, MDM_CMD_TIMEOUT);
+		if (ret < 0) {
+			goto error;
+		}
 	}
 
 	errno = 0;
 	return 0;
+
+error:
+	LOG_ERR("%s ret:%d", log_strdup(buf), ret);
+	modem_socket_put(&mdata.socket_config, sock->sock_fd);
+	errno = -ret;
+	return -1;
 }
 
 /*
@@ -1639,6 +1762,92 @@
 	return (ssize_t)sent;
 }
 
+#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS)
+static int map_credentials(struct modem_socket *sock, const void *optval, socklen_t optlen)
+{
+	sec_tag_t *sec_tags = (sec_tag_t *)optval;
+	int ret = 0;
+	int tags_len;
+	sec_tag_t tag;
+	int id;
+	int i;
+	struct tls_credential *cert;
+
+	if ((optlen % sizeof(sec_tag_t)) != 0 || (optlen == 0)) {
+		return -EINVAL;
+	}
+
+	tags_len = optlen / sizeof(sec_tag_t);
+	/* For each tag, retrieve the credentials value and type: */
+	for (i = 0; i < tags_len; i++) {
+		tag = sec_tags[i];
+		cert = credential_next_get(tag, NULL);
+		while (cert != NULL) {
+			switch (cert->type) {
+			case TLS_CREDENTIAL_CA_CERTIFICATE:
+				id = 0;
+				break;
+			case TLS_CREDENTIAL_NONE:
+			case TLS_CREDENTIAL_PSK:
+			case TLS_CREDENTIAL_PSK_ID:
+			default:
+				/* Not handled */
+				return -EINVAL;
+			}
+			struct modem_cmd cmd[] = {
+				MODEM_CMD("+USECMNG: ", on_cmd_cert_write, 3U, ","),
+			};
+			ret = send_cert(sock, cmd, 1, cert->buf, cert->len, id);
+			if (ret < 0) {
+				return ret;
+			}
+
+			cert = credential_next_get(tag, cert);
+		}
+	}
+
+	return 0;
+}
+#else
+static int map_credentials(struct modem_socket *sock, const void *optval, socklen_t optlen)
+{
+	return -EINVAL;
+}
+#endif
+
+static int offload_setsockopt(void *obj, int level, int optname,
+			      const void *optval, socklen_t optlen)
+{
+	struct modem_socket *sock = (struct modem_socket *)obj;
+
+	int ret;
+
+	if (IS_ENABLED(CONFIG_NET_SOCKETS_SOCKOPT_TLS) && level == SOL_TLS) {
+		switch (optname) {
+		case TLS_SEC_TAG_LIST:
+			ret = map_credentials(sock, optval, optlen);
+			break;
+		case TLS_HOSTNAME:
+			LOG_WRN("TLS_HOSTNAME option is not supported");
+			return -EINVAL;
+		case TLS_PEER_VERIFY:
+			if (*(uint32_t *)optval != TLS_PEER_VERIFY_REQUIRED) {
+				LOG_WRN("Disabling peer verification is not supported");
+				return -EINVAL;
+			}
+			ret = 0;
+			break;
+		default:
+			return -EINVAL;
+		}
+	} else {
+		return -EINVAL;
+	}
+
+	return ret;
+}
+
+
 static const struct socket_op_vtable offload_socket_fd_op_vtable = {
 	.fd_vtable = {
 		.read = offload_read,
@@ -1654,7 +1863,7 @@
 	.accept = NULL,
 	.sendmsg = offload_sendmsg,
 	.getsockopt = NULL,
-	.setsockopt = NULL,
+	.setsockopt = offload_setsockopt,
 };
 
 static bool offload_is_supported(int family, int type, int proto)