/*
 * Copyright (c) 2003-2016
 * Distributed Systems Software.  All rights reserved.
 * See the file LICENSE for redistribution information.
 */

/*
 * SSL support library
 * Uses the OpenSSL library.
 */

#ifndef lint
static const char copyright[] =
"Copyright (c) 2003-2016\n\
Distributed Systems Software.  All rights reserved.";
static const char revid[] =
  "$Id: ssllib.c 2912 2016-10-18 19:54:07Z brachman $";
#endif

#include "dacs_ssl.h"

#include <sys/ioctl.h>

/*
 * Module name for this file's log messages; presumably consulted by the
 * log_msg()/log_err() macros (it is not referenced directly here).
 */
static char *log_module_name = "ssllib";

/*
 * Shared state defined elsewhere.  ssl_setup_client_ctx() records the client
 * context's certificate store and the active Ssl_conf here so that
 * ssl_verify_callback() and ssl_transfer_data() can reach the configuration.
 * This is a process-wide, non-thread-safe kludge.
 */
extern Ssl_global_conf global_conf;

/*
 * Seed OpenSSL's pseudo random number generator with RAND_SEED_BYTES bytes
 * taken from the file RFILE.
 * Return 0 if exactly that many bytes were loaded, -1 otherwise.
 */
static int
seed_prng(char *rfile)
{
  int nloaded;

  nloaded = RAND_load_file(rfile, RAND_SEED_BYTES);

  return((nloaded == RAND_SEED_BYTES) ? 0 : -1);
}

/*
 * The callback function for SSL_CTX_set_verify().
 * Parameter OK is non-zero iff verification of the current certificate
 * succeeded; CTX holds the certificate being verified and the verification
 * status.
 * Return non-zero if we want to accept the certificate (regardless of OK),
 * 0 otherwise.
 *
 * A failing certificate is accepted anyway if its depth in the chain does
 * not exceed CONF->verify_depth, or if the error is a self-signed
 * certificate in the chain and CONF->verify_allow_self_signed is set.
 * CONF->verify_error and the verification result of CONF->ssl are updated
 * to match the decision.
 * If CTX does not belong to the certificate store recorded in global_conf,
 * no configuration is available and OK is returned unchanged.
 */
int
ssl_verify_callback(int ok, X509_STORE_CTX *ctx)
{
#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL)
  X509_STORE *store = X509_STORE_CTX_get0_store(ctx);
#else
  X509_STORE *store = ctx->ctx;
#endif
  int depth, err;
  char data[256];
  Ssl_conf *conf;
  X509 *cert;

  /*
   * Why won't OpenSSL let us pass user data here (e.g., Ssl_conf) so that
   * we can print this as we require?
   * There's X509_STORE_CTX_[gs]et_app_data(), but that is utterly
   * undocumented...
   *
   * So we use a non-thread safe kludge... with a sanity check.
   */
  if (store == global_conf.store)
	conf = global_conf.conf;
  else
	conf = NULL;

  /*
   * Without a configuration there is nothing to consult or update, so leave
   * OpenSSL's decision alone (and never dereference a null CONF).
   */
  if (conf == NULL)
	return(ok);

  if (ok) {
	/* This certificate passed; record success for the connection. */
	SSL_set_verify_result(conf->ssl, X509_V_OK);
	return(ok);
  }

  depth = X509_STORE_CTX_get_error_depth(ctx);
  err = X509_STORE_CTX_get_error(ctx);
  cert = X509_STORE_CTX_get_current_cert(ctx);

  if (conf->verbose_flag) {
	log_msg((LOG_TRACE_LEVEL, "certificate verification error:"));
	log_msg((LOG_TRACE_LEVEL, "    depth=%d", depth));
	if (cert != NULL) {
	  X509_NAME_oneline(X509_get_issuer_name(cert), data, sizeof(data));
	  log_msg((LOG_TRACE_LEVEL, "    issuer=\"%s\"", data));
	  X509_NAME_oneline(X509_get_subject_name(cert), data, sizeof(data));
	  log_msg((LOG_TRACE_LEVEL, "    subject=\"%s\"", data));
	}
	log_msg((LOG_TRACE_LEVEL, "    err=\"%d:%s\"",
			 err, X509_verify_cert_error_string(err)));
  }

  /*
   * NOTE(review): the depth rule accepts *any* verification error (expired,
   * untrusted, ...) at depth <= verify_depth, not only chain-length errors,
   * and a rejection is always reported as CERT_CHAIN_TOO_LONG whatever the
   * real error was; confirm this permissive policy is intended.
   */
  if (conf->verify_depth >= depth
	  || (err == X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN
		  && conf->verify_allow_self_signed)) {
	ok = 1;
	conf->verify_error = X509_V_OK;
	SSL_set_verify_result(conf->ssl, X509_V_OK);
  }
  else {
	conf->verify_error = X509_V_ERR_CERT_CHAIN_TOO_LONG;
	SSL_set_verify_result(conf->ssl, X509_V_ERR_CERT_CHAIN_TOO_LONG);
  }

  return(ok);
}

/*
 * Try each compiled regex in CONF->peer_match_vec against the string PEER.
 * Return 1 as soon as one matches, 0 if none do (or no regexes are
 * configured), or -1 if regexec() reports an error; in the error case, if
 * ERRBUF is non-NULL, *ERRBUF is set to a strdup()'ed description.
 */
static int
peer_match(Ssl_conf *conf, char *peer, char **errbuf)
{
  int k, rc;
  const char *outcome;
  Ssl_peer_match **pm;

  if (conf->peer_match_vec == NULL)
	return(0);

  pm = (Ssl_peer_match **) dsvec_base(conf->peer_match_vec);
  for (k = 0; pm[k] != NULL; k++) {
	rc = regexec(pm[k]->regex, peer, 0, NULL, 0);

	if (conf->verbose_flag) {
	  if (rc == 0)
		outcome = "succeeds";
	  else if (rc == REG_NOMATCH)
		outcome = "fails";
	  else
		outcome = "error!";
	  log_msg((LOG_TRACE_LEVEL, "matching regex \"%s\" : \"%s\" %s",
			   pm[k]->regex_str, peer, outcome));
	}

	if (rc == REG_NOMATCH)
	  continue;

	if (rc == 0)
	  return(1);

	/* regexec() itself failed. */
	if (errbuf != NULL) {
	  char err[128];

	  regerror(rc, pm[k]->regex, err, sizeof(err));
	  *errbuf = strdup(err);
	}
	return(-1);
  }

  return(0);
}

/*
 * After the connection has been established, verify that the name of the
 * peer (PEER: if client, who we connected to; if server, who connected to us)
 * matches either the subjectAltName or the commonName
 * in the peer's cert, following RFC 2818:
 *   If the hostname is available, the client MUST check it against the
 *   server's identity as presented in the server's Certificate message,
 *   in order to prevent man-in-the-middle attacks.
 *   If a subjectAltName extension of type dNSName is present, that MUST
 *   be used as the identity. Otherwise, the (most specific) Common Name
 *   field in the Subject field of the certificate MUST be used.
 *
 * If a regex match has been requested, try that first; if it fails, try
 * an exact, case insensitive string match.
 * The regex will be matched against the value of a dNSName field with
 * "DNS:" prepended (e.g., "DNS:amd2.dss.ca").
 * It will be matched against the value of an iPAddress field with
 * "IP Address:" prepended.
 *
 * Return an OpenSSL error code.
 *
 * XXX This is not yet a complete implementation of RFC 2818 and RFC 2459:
 *   o doesn't handle wildcard matches in the subjectAltName/dNSName
 * See RFC 2459 4.2.1.7 and RFC 2818 3.1
 */
long
ssl_post_connection_check(SSL *ssl, char *peer, Ssl_conf *conf)
{
  int extcount, ok, st;
  long err;
  char *errmsg;
  X509 *cert;
  X509_NAME *subj;
 
  ok = 0;

  /*
   * This can return NULL if anonymous ciphers are enabled or if the server
   * doesn't require a client certificate.
   */
  if ((cert = SSL_get_peer_certificate(ssl)) == NULL) {
	if (conf->verify_type == SSL_VERIFY_NONE)
	  return(X509_V_OK);
	log_msg((LOG_ERROR_LEVEL, "Cannot get peer cert"));
	goto err_occurred;
  }

  if (peer == NULL) {
	if (conf->verify_type == SSL_VERIFY_NONE)
	  return(X509_V_OK);
	log_msg((LOG_ERROR_LEVEL, "peer argument is NULL"));
	goto err_occurred;
  }

  if ((extcount = X509_get_ext_count(cert)) > 0) {
	int i;
 
	/*
	 * There are X.509v3 extensions present in the peer's cert.
	 * Look through them for a subjectAltName.
	 */
	for (i = 0;  i < extcount;  i++) {
	  char *extstr, *san;
	  X509_EXTENSION *ext;
 
	  ext = X509_get_ext(cert, i);
	  extstr
		= (char *) OBJ_nid2sn(OBJ_obj2nid(X509_EXTENSION_get_object(ext)));
 
	  if (conf->verbose_flag > 1)
		log_msg((LOG_TRACE_LEVEL, "Extension field: \"%s\"", extstr));

	  if (streq(extstr, "subjectAltName")) {
		int j;
#if (OPENSSL_VERSION_NUMBER >= 0x0090800fL)
		const unsigned char *data;
#else
		unsigned char *data;
#endif
		STACK_OF(CONF_VALUE) *val;
		CONF_VALUE *nval;
		const X509V3_EXT_METHOD *meth;
		void *ext_str;
 
		ext_str = NULL;
		if ((meth = X509V3_EXT_get(ext)) == NULL)
		  break;

#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL)
		data = NULL;
		ext_str = X509V3_EXT_d2i(ext);
#elif (OPENSSL_VERSION_NUMBER > 0x00907000L)
		data = ext->value->data;
		if (meth->it)
		  ext_str = ASN1_item_d2i(NULL, &data, ext->value->length,
								  ASN1_ITEM_ptr(meth->it));
		else
		  ext_str = meth->d2i(NULL, &data, ext->value->length);
#else
		data = ext->value->data;
		ext_str = meth->d2i(NULL, &data, ext->value->length);
#endif
		val = meth->i2v(meth, ext_str, NULL);

		for (j = 0;  j < sk_CONF_VALUE_num(val);  j++) {
		  nval = sk_CONF_VALUE_value(val, j);
		  san = ds_xprintf("%s:%s", nval->name, nval->value);
		  if (conf->verbose_flag)
			log_msg((LOG_TRACE_LEVEL, "subjectAltName=\"%s\"", san));

		  if ((st = peer_match(conf, san, &errmsg)) == 1) {
			ok = 1;
			break;
		  }
		  else if (st == -1) {
			log_msg((LOG_ERROR_LEVEL, "regex exec error: %s", errmsg));
			break;
		  }
		  else if (streq(nval->name, "DNS")
				   && strcaseeq(nval->value, peer)) {
			if (conf->verbose_flag)
			  log_msg((LOG_TRACE_LEVEL, "\"DNS\" field matches %s", peer));
			ok = 1;
			break;
		  }
		  else if (streq(nval->name, "IP Address")
				   && streq(nval->value, peer)) {
			if (conf->verbose_flag)
			  log_msg((LOG_TRACE_LEVEL,
					   "\"IP Address\" field matches %s", peer));
			ok = 1;
			break;
		  }
		}
	  }
	  if (ok)
		break;
	}
  }
 
  /*
   * Get the commonName, if only for debugging.
   * Only if we were unsuccessful with the subjectAltName above, try to match
   * against the commonName.
   */
  if ((subj = X509_get_subject_name(cert)) != NULL) {
	char common_name[1024];

	if (X509_NAME_get_text_by_NID(subj, NID_commonName, common_name,
								  sizeof(common_name)) > 0) {
	  common_name[1023] = '\0';
	  if (conf->verbose_flag)
		log_msg((LOG_TRACE_LEVEL, "commonName=\"%s\"", common_name));
	}

	if (!ok) {
	  /* Do this only if subjectAltName/dNSName failed. */
	  if ((st = peer_match(conf, common_name, &errmsg)) == 1)
		ok = 1;
	  else if (st == -1)
		log_msg((LOG_ERROR_LEVEL, "regex exec error: %s", errmsg));
	  else if (strcaseeq(common_name, peer)) {
		if (conf->verbose_flag)
		  log_msg((LOG_TRACE_LEVEL, "\"commonName\" field matches %s", peer));
		ok = 1;
	  }
	}
  }

  if (!ok)
	goto err_occurred;

  X509_free(cert);
  if ((err = SSL_get_verify_result(ssl)) == X509_V_OK)
	return(X509_V_OK);

  if (err == X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN
	  && conf->verify_allow_self_signed)
	return(X509_V_OK);

  if (conf->verify_type == SSL_VERIFY_NONE)
	return(X509_V_OK);

  return(err);

 err_occurred:

  if (cert != NULL)
	X509_free(cert);

  return(X509_V_ERR_APPLICATION_VERIFICATION);
}

/*
 * Get data from the application to send to the server.
 * If PTR is not NULL, it points to a buffer of length LEN that will receive
 * the data; return the actual number of bytes put into the buffer, 0 if no
 * data is available right now (stdin would block or the read was
 * interrupted), or -1 on a read error (which also marks EOF).
 * This is the default function, which simply obtains the data from stdin.
 * If PTR is NULL, return 0 if EOF has not been reached, 1 otherwise.
 * CTX points to the caller's IO_state.
 *
 * XXX input file/stream should be configurable
 */
ssize_t
ssl_get_data(void *ctx, unsigned char *ptr, size_t len)
{
  ssize_t n;
  IO_state *state;

  state = (IO_state *) ctx;

  if (ptr == NULL)
	return(state->eof);

  /* read() returns ssize_t; don't narrow it to int. */
  n = read(0, ptr, len);
  if (n == -1) {
	/* Nothing to hand over yet; the caller will ask again. */
	if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
	  return(0);
	state->eof = 1;
	state->last_errno = errno;
	return(-1);
  }

  if (n == 0)
	state->eof = 1;

  return(n);
}

/*
 * Give data read from the server back to the application.
 * PTR points to the data and LEN is its length.
 * This is the default function, which simply writes the data to stdout.
 * Interrupted writes are retried. If stdout is non-blocking and full, the
 * function waits until it is writable. Stdout may share the file description
 * that ssl_transfer() makes non-blocking on stdin.
 * Return 0 if not all the bytes were written, otherwise return LEN.
 * CTX is not used.
 *
 * XXX output file/stream should be configurable
 */
ssize_t
ssl_put_data(void *ctx, unsigned char *ptr, size_t len)
{
  int fd;
  unsigned char *p;
  size_t nleft;
  ssize_t st;

  fd = 1;
  nleft = len;
  p = ptr;
  while (nleft > 0) {
	st = write(fd, p, nleft);
	if (st == -1) {
	  if (errno == EINTR)
		continue;

	  if (errno == EAGAIN || errno == EWOULDBLOCK) {
		fd_set wfds;

		/* Wait for stdout to drain instead of dropping the data. */
		FD_ZERO(&wfds);
		FD_SET(fd, &wfds);
		if (select(fd + 1, NULL, &wfds, NULL, NULL) != -1 || errno == EINTR)
		  continue;
	  }

	  log_err((LOG_ERROR_LEVEL, "write error"));
	  return(0);
	}
	nleft -= st;
	p += st;
  }

  return((ssize_t) len);
}

/*
 * One-time library setup: initialize OpenSSL, load its error strings and
 * seed the PRNG from CONF->rand_seed_file.
 * Return 0 once initialization has succeeded (now or on an earlier call),
 * -1 on failure; after a failure a later call tries again.
 */
int
ssl_init(Ssl_conf *conf)
{
  static int initialized = 0;

  if (!initialized) {
	if (SSL_library_init() == 0) {
	  log_msg((LOG_ERROR_LEVEL, "OpenSSL initialization failed!"));
	  return(-1);
	}

	SSL_load_error_strings();

	if (seed_prng(conf->rand_seed_file) == -1)
	  return(-1);

	initialized = 1;
  }

  return(0);
}

/*
 * Return the text of OpenSSL's pending error queue (which ERR_print_errors()
 * empties).  The result is a newly allocated string when there is text,
 * otherwise a static empty string.  Callers cannot tell which, so they do not
 * free it.
 */
char *
ssl_get_error_messages(void)
{
  char *msg;
  BIO *bp;
  BUF_MEM *bptr;

  if ((bp = BIO_new(BIO_s_mem())) == NULL)
	return("");

  ERR_print_errors(bp);

  msg = NULL;
  bptr = NULL;
  BIO_get_mem_ptr(bp, &bptr);
  if (bptr != NULL && bptr->length > 0)
	msg = strndup(bptr->data, bptr->length);

  /*
   * MSG is a copy. Memory BIOs default to BIO_CLOSE, so BIO_free() also
   * releases the underlying buffer; setting BIO_NOCLOSE here would leak it.
   */
  BIO_free(bp);

  return((msg != NULL) ? msg : "");
}

/*
 * Create a client SSL context configured from CONF and create the connection
 * object CONF->ssl from it. The configuration covers:
 *   o CA locations and/or default verify paths;
 *   o the client certificate chain and private key;
 *   o protocol options and the cipher list;
 *   o verification mode and depth.
 * If CONF->use_sni is set and SERVERNAME is non-NULL, SERVERNAME is requested
 * via Server Name Indication.
 * The context's certificate store and CONF are recorded in global_conf for
 * ssl_verify_callback().
 * Return that context, or NULL on error (the context is freed).
 */
SSL_CTX *
ssl_setup_client_ctx(Ssl_conf *conf, char *servername)
{
  SSL_CTX *ctx;
  const SSL_METHOD *m;

  if ((m = SSLv23_method()) == NULL || (ctx = SSL_CTX_new(m)) == NULL) {
	log_msg((LOG_ERROR_LEVEL, "error initializing context"));
	return(NULL);
  }

  if (conf->ca_cert_file != NULL || conf->ca_cert_dir != NULL) {
	if (SSL_CTX_load_verify_locations(ctx, conf->ca_cert_file,
									  conf->ca_cert_dir) != 1) {
	  log_msg((LOG_ERROR_LEVEL, "error loading CA file and/or directory"));
	  log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	  goto fail;
	}
  }

  if (conf->use_default_verify_paths) {
	if (SSL_CTX_set_default_verify_paths(ctx) != 1) {
	  log_msg((LOG_ERROR_LEVEL, "error loading default CA paths"));
	  log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	  goto fail;
	}
  }

  if (conf->cert_chain_file != NULL) {
	if (SSL_CTX_use_certificate_chain_file(ctx, conf->cert_chain_file) != 1) {
	  log_msg((LOG_ERROR_LEVEL, "error loading certificate chain file: %s",
			   conf->cert_chain_file));
	  log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	  goto fail;
	}
  }

  if (conf->key_file != NULL) {
	if (SSL_CTX_use_PrivateKey_file(ctx, conf->key_file, conf->key_file_type)
		!= 1) {
	  log_msg((LOG_ERROR_LEVEL, "error loading private key from file: %s",
			   conf->key_file));
	  log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	  goto fail;
	}
  }

  /*
   * Disallow SSLv2 and SSLv3 due to security defects (SSLv3 is prohibited
   * by RFC 7568); with SSLv23_method() only TLS versions remain negotiable.
   */
  SSL_CTX_set_options(ctx, SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);

  if (SSL_CTX_set_cipher_list(ctx, conf->cipher_list) != 1) {
	log_msg((LOG_ERROR_LEVEL, "error setting cipher list"));
	log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	goto fail;
  }

  if ((conf->ssl = SSL_new(ctx)) == NULL) {
	log_msg((LOG_ERROR_LEVEL, "error creating SSL connection object"));
	log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	goto fail;
  }

  /*
   * SNI - Server Name Indication (RFC 6066, S3)
   * Testing service at:
   *   https://sni.velox.ch/
   */
#ifdef OPENSSL_NO_TLSEXT
  if (conf->use_sni && servername != NULL)
	log_msg((LOG_WARN_LEVEL, "OpenSSL TLS extensions are unavailable"));
#else
  if (conf->use_sni && servername != NULL) {
	if (SSL_set_tlsext_host_name(conf->ssl, servername) == 0) {
	  log_msg((LOG_ERROR_LEVEL, "error setting SNI servername to \"%s\"",
			   servername));
	  log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	}
	else
	  log_msg((LOG_TRACE_LEVEL, "set SNI servername to \"%s\"", servername));
  }
  else
	log_msg((LOG_TRACE_LEVEL, "SNI available but not requested"));
#endif

  /*
   * NOTE(review): SSL_new() above already copied the context's verify mode,
   * verify callback and verify parameters (including depth) into conf->ssl,
   * so these two calls only affect SSL objects created from CTX later.
   * OpenSSL still verifies the server chain and records the outcome, which
   * ssl_post_connection_check() examines via SSL_get_verify_result().
   * Moving these calls above SSL_new() would make ssl_verify_callback()
   * effective for conf->ssl, which (given its verify_depth handling) could
   * relax verification; confirm the intended policy before reordering.
   */
  SSL_CTX_set_verify(ctx, conf->verify_type, ssl_verify_callback);
  if (conf->verify_depth > 0)
	SSL_CTX_set_verify_depth(ctx, conf->verify_depth);

  global_conf.store = SSL_CTX_get_cert_store(ctx);
  global_conf.conf = conf;

  return(ctx);

 fail:
  /* An SSL_CTX must be released with SSL_CTX_free(), not free(). */
  SSL_CTX_free(ctx);
  return(NULL);
}

/*
 * Establish an SSL connection to HOSTNAME:PORT, then get data from the
 * application to write to the SSL connection by calling USER_GET_DATA
 * (with USER_GET_ARG passed as the first argument) and return data read
 * from the SSL connection by calling USER_PUT_DATA (with USER_PUT_ARG
 * passed as the first argument).
 * If either function is NULL, the corresponding default function and
 * default argument are used.
 * The SSL context and connection are released before returning, on success
 * and on failure alike, and CONF->ssl is reset to NULL.
 * Return -1 on error, 0 otherwise.
 */
int
ssl_transfer(char *hostname, char *port, Ssl_conf *conf,
			 Ssl_io_callback user_get_data, void *user_get_arg,
			 Ssl_io_callback user_put_data, void *user_put_arg)
{
  int rc;
  char *errmsg, *server;
  long err;
  void *get_arg, *put_arg;
  Ssl_io_callback get_data, put_data;
  BIO *conn;
  SSL *ssl;
  SSL_CTX *ctx;
  IO_state get_state;

  if (hostname == NULL || port == NULL || conf == NULL)
	return(-1);

  if (user_get_data != NULL) {
	get_data = user_get_data;
	get_arg = user_get_arg;
  }
  else {
	get_data = ssl_get_data;
	get_arg = &get_state;
  }

  if (user_put_data != NULL) {
	put_data = user_put_data;
	put_arg = user_put_arg;
  }
  else {
	put_data = ssl_put_data;
	put_arg = NULL;
  }

  /*
   * NOTE(review): the result is ignored, so e.g. a PRNG seeding failure
   * does not stop the transfer; confirm that this is intended.
   */
  ssl_init(conf);
  if ((ctx = ssl_setup_client_ctx(conf, hostname)) == NULL)
	return(-1);
  ssl = conf->ssl;
  rc = -1;

  server = ds_xprintf("%s:%s", hostname, port);
  conn = BIO_new_connect(server);
  if (conn == NULL) {
	log_msg((LOG_ERROR_LEVEL, "error creating connection BIO"));
	log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	goto done;
  }

  if (BIO_do_connect(conn) <= 0) {
	log_msg((LOG_ERROR_LEVEL, "error connecting to %s", server));
	log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	/* CONN is not yet owned by SSL, so free it here. */
	BIO_free_all(conn);
	goto done;
  }

  /* From here on SSL owns CONN; SSL_free() releases it. */
  SSL_set_bio(ssl, conn, conn);
  if (SSL_connect(ssl) <= 0) {
	log_msg((LOG_ERROR_LEVEL, "SSL_connect() to %s failed", server));
	log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	goto done;
  }

  if ((err = ssl_post_connection_check(ssl, hostname, conf)) != X509_V_OK) {
	log_msg((LOG_ERROR_LEVEL, "Peer certificate error: %s",
			 X509_verify_cert_error_string(err)));
	log_msg((LOG_ERROR_LEVEL, "post connection check failed"));
	log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	goto done;
  }

  if (conf->verbose_flag)
	log_msg((LOG_TRACE_LEVEL, "SSL Connection opened to %s", server));

  if (net_set_nonblocking(0, &errmsg) == -1)
	log_msg((LOG_ERROR_LEVEL, "%s", errmsg));

  get_state.eof = 0;
  get_state.last_errno = 0;
  ssl_transfer_data(ssl, get_data, get_arg, put_data, put_arg);

  if (conf->verbose_flag)
	log_msg((LOG_TRACE_LEVEL, "SSL Connection closed"));

  rc = 0;

 done:
  SSL_free(ssl);
  /* Don't leave a dangling pointer to the freed connection in CONF. */
  conf->ssl = NULL;
  SSL_CTX_free(ctx);

  return(rc);
}

/*
 * Check the SSL connection for readiness to read and/or write.
 * CAN_READ will be set non-zero iff we can read from the SSL connection,
 * which includes decrypted data already buffered inside SSL (SSL_pending()),
 * since select() cannot see that.
 * CAN_WRITE will be set non-zero iff we can write to the SSL connection.
 * select() is restarted if interrupted by a signal.
 * Return -1 on error, 0 otherwise.
 */
static int
check_io(SSL *ssl, int *can_read, int *can_write)
{
  int fd, pending, rc;
  fd_set rfds, wfds;
  struct timeval no_wait, *timeout;

  fd = SSL_get_fd(ssl);
  if (fd < 0 || fd >= FD_SETSIZE) {
	/* FD_SET() on such a descriptor would be undefined behavior. */
	log_msg((LOG_ERROR_LEVEL, "ssllib: check_io: unusable fd %d", fd));
	return(-1);
  }

  /* With buffered plaintext available, only poll the socket. */
  pending = (SSL_pending(ssl) > 0);
  if (pending) {
	no_wait.tv_sec = 0;
	no_wait.tv_usec = 0;
	timeout = &no_wait;
  }
  else
	timeout = NULL;

  do {
	FD_ZERO(&rfds);
	FD_ZERO(&wfds);
	FD_SET(fd, &rfds);
	FD_SET(fd, &wfds);
	rc = select(fd + 1, &rfds, &wfds, NULL, timeout);
  } while (rc == -1 && errno == EINTR);

  if (rc == -1) {
	perror("ssllib: check_io: select");
	return(-1);
  }

  *can_read = (pending || FD_ISSET(fd, &rfds)) ? 1 : 0;
  *can_write = FD_ISSET(fd, &wfds) ? 1 : 0;

  return(0);
}

/*
 * Write SLEN bytes from STR to SSL, continuing after partial writes and
 * retrying after SSL_ERROR_WANT_READ/SSL_ERROR_WANT_WRITE.
 * Return the number of bytes written if successful (all SLEN bytes were
 * written), -1 otherwise.
 */
static int
ssl_write(SSL *ssl, unsigned char *str, size_t slen)
{
  int nwritten, st;
  size_t len;

  if (slen == 0)
	return(0);

  len = slen;
  nwritten = 0;
  while (1) {
	/*
	 * Always continue from the first unwritten byte. A retry after
	 * WANT_READ/WANT_WRITE repeats the same arguments, as OpenSSL requires.
	 */
	st = SSL_write(ssl, str + nwritten, len);

	switch (SSL_get_error(ssl, st)) {
	case SSL_ERROR_NONE:
	  /* Account for the bytes written; done once nothing is left. */
	  nwritten += st;
	  len -= st;
	  if (len == 0)
		return(nwritten);

	  break;

	case SSL_ERROR_ZERO_RETURN:
	  /* Connection closed. */
	  log_msg((LOG_ERROR_LEVEL, "Error zero return, nwritten=%d", nwritten));
	  return(-1);

	case SSL_ERROR_WANT_READ:
	case SSL_ERROR_WANT_WRITE:
	  /* Retry. */
	  break;

	default:
	  /* ERROR */
	  log_msg((LOG_ERROR_LEVEL, "Bad SSL_get_error(), nwritten=%d", nwritten));
	  return(-1);
	}
  }
  /*NOTREACHED*/
}

/*
 * Format FMT and its arguments printf-style and write the result to SSL.
 * Return 0 if successful, -1 if formatting or writing fails.
 */
int
ssl_printf(SSL *ssl, const char *fmt, ...)
{
  int fmt_failed;
  Ds out;
  va_list args;

  ds_init(&out);
  out.exact_flag = 1;

  va_start(args, fmt);
  fmt_failed = (ds_vasprintf(&out, fmt, args) == -1);
  va_end(args);

  if (fmt_failed)
	return(-1);

  /* ds_len() presumably counts the terminating null byte, which is not sent. */
  return((ssl_write(ssl, (unsigned char *) ds_buf(&out), ds_len(&out) - 1)
		  == -1) ? -1 : 0);
}

/*
 * Write the null-terminated string STR to SSL, then write a newline.
 * Return 0 if successful, -1 otherwise (including when STR is NULL).
 */
int
ssl_puts(SSL *ssl, char *str)
{

  if (str == NULL)
	return(-1);

  if (ssl_write(ssl, (unsigned char *) str, strlen(str)) == -1)
	return(-1);

  if (ssl_write(ssl, (unsigned char *) "\n", 1) == -1)
	return(-1);

  return(0);
}

/*
 * Read one byte from SSL and store it in *BUFP as a value in 0..255.
 * Retries after SSL_ERROR_WANT_READ/SSL_ERROR_WANT_WRITE.
 * Return 1 if a character was read, 0 if EOF, or -1 if an error occurs
 * (after logging the available error details).
 */
static int
ssl_readch(SSL *ssl, int *bufp)
{
  int err, saved_errno, st;
  unsigned char ch;
  unsigned long errcode;

  while (1) {
	/* Read into a byte; reading into an int would leave its other bytes unset. */
	st = SSL_read(ssl, &ch, 1);

	switch ((err = SSL_get_error(ssl, st))) {
	case SSL_ERROR_NONE:
	  *bufp = ch;
	  return(1);

	case SSL_ERROR_ZERO_RETURN:
	  /* Connection closed. */
	  return(0);

	case SSL_ERROR_WANT_READ:
	case SSL_ERROR_WANT_WRITE:
	  /* Retry. */
	  break;

	default:
	  /* Error? */
	  saved_errno = errno;
	  if (err == SSL_ERROR_SYSCALL && st == 0 && saved_errno == 0
		  && ERR_peek_error() == 0) {
		/*
		 * It seems that it's just a slightly abnormal EOF (apparently the
		 * peer closed without a close_notify).  No byte was read, so
		 * report EOF rather than a character.
		 */
		return(0);
	  }

	  log_msg((LOG_ERROR_LEVEL, "error return from SSL_get_error(): "));
	  while ((errcode = ERR_get_error()) != 0) {
		char *errmsg;

		if ((errmsg = ERR_error_string(errcode, NULL)) != NULL)
		  log_msg((LOG_ERROR_LEVEL, "%s", errmsg));
	  }

	  if (err == SSL_ERROR_SYSCALL) {
		log_msg((LOG_ERROR_LEVEL, "SSL_ERROR_SYSCALL: ret=%d", st));
		if (st == -1)
		  log_msg((LOG_ERROR_LEVEL, "%s",
				   (saved_errno != 0) ? strerror(saved_errno) : "unknown"));
		else if (st == 0)
		  log_msg((LOG_ERROR_LEVEL, "Invalid EOF received"));
	  }
	  else if (err == SSL_ERROR_SSL)
		log_msg((LOG_ERROR_LEVEL, "SSL_ERROR_SSL"));
	  else if (err == SSL_ERROR_WANT_CONNECT)
		log_msg((LOG_ERROR_LEVEL, "SSL_ERROR_WANT_CONNECT"));
	  else if (err == SSL_ERROR_WANT_ACCEPT)
		log_msg((LOG_ERROR_LEVEL, "SSL_ERROR_WANT_ACCEPT"));
	  else if (err == SSL_ERROR_WANT_X509_LOOKUP)
		log_msg((LOG_ERROR_LEVEL, "SSL_ERROR_WANT_X509_LOOKUP"));
	  else
		log_msg((LOG_ERROR_LEVEL, "unknown error"));

	  log_msg((LOG_ERROR_LEVEL, "error(s) occurred"));
	  log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));
	  return(-1);
	}
  }
  /*NOTREACHED*/
}

/*
 * Read up to (BUFLEN - 1) characters from SSL into BUFP, until either a
 * newline or EOF is seen.  The newline is not copied.  The buffer is
 * null-terminated.
 * Return 0 if successful (at EOF the line may be empty), -1 if an error
 * occurs or the input is too long; when the line is too long the buffer
 * holds the null-terminated prefix that fit.
 *
 * XXX Dsio should probably be extended to take a
 * function and argument to return an input character and handle EOF and
 * errors...
 */
int
ssl_gets(SSL *ssl, char *bufp, size_t buflen)
{
  int ch, st;
  size_t len;
  char *p;

  if (bufp == NULL || buflen < 1)
	return(-1);

  p = bufp;
  len = 0;

  while ((st = ssl_readch(ssl, &ch)) == 1 && ch != '\n') {
	if (++len == buflen) {
	  /* P is at bufp[buflen - 1], so there is room for the terminator. */
	  *p = '\0';
	  return(-1);
	}
	*p++ = (char) ch;
  }

  /* Check the read status; CH is not meaningful unless a byte was read. */
  if (st == -1)
	return(-1);

  *p = '\0';
  return(0);
}

/*
 * Shut down the SSL connection; return whatever SSL_shutdown() returns.
 */
int
ssl_eof(SSL *ssl)
{
  int rc;

  rc = SSL_shutdown(ssl);
  return(rc);
}

/*
 * Transfer data between client (B) and server (A), asynchronously.
 * Call DATA_TO_NETWORK with argument TO_CTX to get data from the application
 * to write to the server and use DATA_FROM_NETWORK with argument FROM_CTX
 * to pass data to the application from the server.
 *
 * DATA_TO_NETWORK is called with a buffer to fill and returns the number of
 * bytes supplied, 0 if none are available yet, or a negative value on error;
 * called with a null buffer it returns non-zero once its input is exhausted.
 * DATA_FROM_NETWORK returns 0 if it could not accept the data.
 * Server data is passed on as soon as it is read, unless the global
 * configuration requests output buffering.  In that case it accumulates until
 * the buffer fills or the transfer ends, and each buffered byte is delivered
 * at most once.
 * The socket is made non-blocking for the transfer.  It is made blocking
 * again (exiting the program if that fails) before the connection is shut
 * down.
 */
void
ssl_transfer_data(SSL *ssl, Ssl_io_callback data_to_network, void *to_ctx,
				  Ssl_io_callback data_from_network, void *from_ctx)
{
  int eof_B2A, err, saved_errno, st;
  int have_data_A2B, have_data_B2A;
  int can_read_A, can_write_A;
  int read_waiton_write_A, read_waiton_read_A;
  int write_waiton_write_A, write_waiton_read_A;
  char *errmsg;
  ssize_t nget;
  size_t A2B_len, B2A_len;
  unsigned char A2B[BUF_SIZE], B2A[BUF_SIZE];
  unsigned long errcode;

  if (net_set_nonblocking(SSL_get_fd(ssl), &errmsg) == -1) {
	log_msg((LOG_ERROR_LEVEL, "%s", errmsg));
	return;
  }

  SSL_set_mode(ssl,
			   SSL_MODE_ENABLE_PARTIAL_WRITE
			   | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);

  A2B_len = B2A_len = 0;
  have_data_A2B = have_data_B2A = 0;
  can_read_A = can_write_A = 0;
  read_waiton_write_A = read_waiton_read_A = 0;
  write_waiton_write_A = write_waiton_read_A = 0;
  eof_B2A = 0;

  while (1) {
	if (check_io(ssl, &can_read_A, &can_write_A) == -1) {
	  log_msg((LOG_ERROR_LEVEL, "check_io failed"));
	  goto err;
	}

	if (!have_data_B2A && !eof_B2A) {
	  /* Keep the callback's ssize_t result; don't narrow it to int. */
	  if ((nget = data_to_network(to_ctx, B2A, sizeof(B2A))) < 0) {
		log_msg((LOG_ERROR_LEVEL, "error getting data for network"));
		goto end;
	  }
	  if (nget != 0) {
		B2A_len = (size_t) nget;
		have_data_B2A = 1;
	  }
	  else
		eof_B2A = (data_to_network(to_ctx, NULL, 0) != 0);
	}

	/*
	 * This "if" statement writes data to A.  It will only be entered if
	 * the following conditions are all true:
	 * 1. We're not in the middle of a read on A;
	 * 2. There's data in the B to A buffer; and
	 * 3. Either we need to read to complete a previously blocked write
	 *    and now A is available to read, or we can write to A
	 *    regardless of whether we're blocking for availability to write.
	 */
	if (!(read_waiton_write_A || read_waiton_read_A)
		&& have_data_B2A
		&& (can_write_A || (can_read_A && write_waiton_read_A))) {

	  write_waiton_read_A = 0;
	  write_waiton_write_A = 0;

	  st = SSL_write(ssl, B2A, B2A_len);

	  switch (SSL_get_error(ssl, st)) {
	  case SSL_ERROR_NONE:
		/*
		 * Adjust the length of the B to A
		 * buffer to be smaller by the number bytes written.  If
		 * the buffer is empty, set the "have data" flags to 0,
		 * or else, move the data from the middle of the buffer to the front.
		 */
		B2A_len -= st;
		if (!B2A_len)
		  have_data_B2A = 0;
		else
		  memmove(B2A, B2A + st, B2A_len);
		break;

	  case SSL_ERROR_ZERO_RETURN:
		/* Connection closed. */
		goto end;

	  case SSL_ERROR_WANT_READ:
		/* Retry the write after A is available for reading. */
		write_waiton_read_A = 1;
		break;

	  case SSL_ERROR_WANT_WRITE:
		/* Retry the write after A is available for writing. */
		write_waiton_write_A = 1;
		break;

	  default:
		/* ERROR */
		log_msg((LOG_ERROR_LEVEL, "bad return value from SSL_get_error()"));
		goto err;
	  }
	}

	/*
	 * This "if" statement reads data from A.  It will only be entered if
	 * the following conditions are all true:
	 * 1. We're not in the middle of a write on A;
	 * 2. There's space left in the A to B buffer; and
	 * 3. Either we need to write to complete a previously blocked read
	 *    and now A is available to write, or we can read from A
	 *    regardless of whether we're blocking for availability to read.
	 */
	if (!(write_waiton_read_A || write_waiton_write_A)
		&& (A2B_len != BUF_SIZE)
		&& (can_read_A || (can_write_A && read_waiton_write_A))) {
	  /* Clear the flags since they're set based on the I/O call's return. */
	  read_waiton_read_A = 0;
	  read_waiton_write_A = 0;

	  /* Read into the buffer after the current position. */
	  st = SSL_read(ssl, A2B + A2B_len, BUF_SIZE - A2B_len);

	  switch ((err = SSL_get_error(ssl, st))) {
	  case SSL_ERROR_NONE:
		/* Update the length, making sure the "have data" flag is set. */
		A2B_len += st;

		/*
		 * If the buffer has filled, or if we're not buffering, flush it.
		 * The global configuration may be unset; treat that as unbuffered.
		 */
		if (A2B_len == BUF_SIZE || global_conf.conf == NULL
			|| !global_conf.conf->buffer_output) {
		  if (data_from_network(from_ctx, A2B, A2B_len) == 0) {
			log_msg((LOG_ERROR_LEVEL,
					 "error passing network data (SSL_ERROR_NONE)"));
			/* Don't offer the same data to the failed sink again below. */
			have_data_A2B = 0;
			goto end;
		  }
		  have_data_A2B = 0;
		  A2B_len = 0;
		}
		else
		  have_data_A2B = 1;

		break;

	  case SSL_ERROR_ZERO_RETURN:
		/*
		 * Connection closed.  Deliver any buffered data now and mark the
		 * buffer empty so that it is not delivered a second time at "end".
		 * ST is not a byte count here, so it is not added to A2B_len.
		 */
		if (A2B_len && data_from_network(from_ctx, A2B, A2B_len) == 0) {
		  log_msg((LOG_ERROR_LEVEL, "error passing network data"));
		  log_msg((LOG_ERROR_LEVEL, "SSL_ERROR_ZERO_RETURN: %lu",
				   (unsigned long) A2B_len));
		}
		have_data_A2B = 0;
		A2B_len = 0;
		goto end;

	  case SSL_ERROR_WANT_READ:
		/* Retry the read after A is available for reading. */
		read_waiton_read_A = 1;
		break;

	  case SSL_ERROR_WANT_WRITE:
		/* Retry the read after A is available for writing. */
		read_waiton_write_A = 1;
		break;

	  default:
		/* Error? */
		saved_errno = errno;
		if (err == SSL_ERROR_SYSCALL && st == 0 && saved_errno == 0
			&& ERR_peek_error() == 0) {
		  /* It seems that it's just a slightly abnormal EOF... */
		  goto end;
		}

		log_msg((LOG_ERROR_LEVEL, "error return from SSL_get_error(): "));
		while ((errcode = ERR_get_error()) != 0) {
		  char *estr;

		  if ((estr = ERR_error_string(errcode, NULL)) != NULL)
			log_msg((LOG_ERROR_LEVEL, "%s", estr));
		}

		if (err == SSL_ERROR_SYSCALL) {
		  log_msg((LOG_ERROR_LEVEL, "SSL_ERROR_SYSCALL: ret=%d", st));
		  if (st == -1)
			log_msg((LOG_ERROR_LEVEL, "%s",
					 (saved_errno != 0) ? strerror(saved_errno) : "unknown"));
		  else if (st == 0)
			log_msg((LOG_ERROR_LEVEL, "Invalid EOF received"));
		}
		else if (err == SSL_ERROR_SSL)
		  log_msg((LOG_ERROR_LEVEL, "SSL_ERROR_SSL"));
		else if (err == SSL_ERROR_WANT_CONNECT)
		  log_msg((LOG_ERROR_LEVEL, "SSL_ERROR_WANT_CONNECT"));
		else if (err == SSL_ERROR_WANT_ACCEPT)
		  log_msg((LOG_ERROR_LEVEL, "SSL_ERROR_WANT_ACCEPT"));
		else if (err == SSL_ERROR_WANT_X509_LOOKUP)
		  log_msg((LOG_ERROR_LEVEL, "SSL_ERROR_WANT_X509_LOOKUP"));
		else
		  log_msg((LOG_ERROR_LEVEL, "unknown error"));
		goto err;
	  }
	}
  }

 err:
  log_msg((LOG_ERROR_LEVEL, "error(s) occurred"));
  log_msg((LOG_ERROR_LEVEL, "%s", ssl_get_error_messages()));

 end:
  if (net_set_blocking(SSL_get_fd(ssl), &errmsg) == -1) {
	log_msg((LOG_ERROR_LEVEL, "%s", errmsg));
	exit(1);
  }

  SSL_shutdown(ssl);

  /* Deliver whatever is still buffered (only data not yet delivered). */
  if (have_data_A2B) {
	if (data_from_network(from_ctx, A2B, A2B_len) == 0) {
	  log_msg((LOG_ERROR_LEVEL, "error passing network data (have_data_A2B)"));
	}
  }
}

/*
 * Fill in the default SSL configuration.  If C is NULL, a new Ssl_conf is
 * obtained with ALLOC(); otherwise C itself is (re)initialized.
 * Return the initialized configuration.
 */
Ssl_conf *
ssl_init_defaults(Ssl_conf *c)
{
  Ssl_conf *cfg;

  if ((cfg = c) == NULL)
	cfg = ALLOC(Ssl_conf);

  /* Connection object and diagnostics. */
  cfg->ssl = NULL;
  cfg->verbose_flag = 0;
  cfg->buffer_output = 0;

  /* Peer certificate verification. */
  cfg->verify_type = SSL_VERIFY_NONE;
  cfg->verify_depth = 0;
  cfg->verify_allow_self_signed = 0;
  cfg->verify_error = X509_V_OK;
  cfg->peer_match_vec = NULL;

  /* Trust anchors and our own credentials. */
  cfg->use_default_verify_paths = 0;
  cfg->ca_cert_file = CA_CERT_FILE;
  cfg->ca_cert_dir = CA_CERT_DIR;
  cfg->cert_chain_file = NULL;
  cfg->key_file = NULL;
  cfg->key_file_type = SSL_FILETYPE_PEM;

  /* Protocol settings. */
  cfg->cipher_list = DEFAULT_CIPHER_LIST;
  cfg->rand_seed_file = DEFAULT_RAND_SEED_FILE;
#ifdef OPENSSL_NO_TLSEXT
  cfg->use_sni = 0;
#else
  cfg->use_sni = 1;
#endif

  return(cfg);
}
