package org.monazilla.v2c;

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.net.IDN;
import java.security.SecureRandom;
import java.util.Hashtable;
import java.util.regex.Pattern;

import org.bouncycastle.tls.DefaultTlsClient;
import org.bouncycastle.tls.ExtensionType;
import org.bouncycastle.tls.ProtocolVersion;
import org.bouncycastle.tls.TlsAuthentication;
import org.bouncycastle.tls.TlsExtensionsUtils;
import org.bouncycastle.tls.TlsSession;
import org.bouncycastle.tls.TlsUtils;
import org.bouncycastle.tls.crypto.TlsCrypto;
import org.bouncycastle.tls.crypto.impl.bc.BcTlsCrypto;

/**
 * TLS client for a single host/port pair, built on BouncyCastle's {@link DefaultTlsClient}.
 * <p>
 * Compared with {@link DefaultTlsClient}, this class:
 * <ul>
 * <li>sends a Server Name Indication (SNI) extension for {@link #host} when RFC 6066 permits it,</li>
 * <li>tolerates an unsolicited ec_point_formats extension from the server (but still validates it), and</li>
 * <li>delegates server authentication to a {@link V2CTLSAuthentication}.</li>
 * </ul>
 *
 * @author koji.hayakawa
 * @see <a href="https://qiita.com/a__i__r/items/b75a381bf46a863b1139">https://qiita.com/a__i__r/items/b75a381bf46a863b1139</a>
 * @see <a href="https://github.com/a--i--r/TLSSocketFactory">https://github.com/a--i--r/TLSSocketFactory</a>
 */
public class V2CTLSClient extends DefaultTlsClient {

	/** Matches dotted-quad IPv4 literals such as {@code 192.0.2.1}. */
	private static final Pattern IPV4_LITERAL = Pattern.compile("\\d{1,3}(?:\\.\\d{1,3}){3}");

	/** Target host name; also the source of the SNI host name. */
	protected String host = "";
	/** Target port. */
	protected int port;
	/** Socket this client was created for. */
	protected V2CTLSSocket tlsSocket;
	/** Authentication object returned by {@link #getAuthentication()}. */
	protected V2CTLSAuthentication tlsAuthentication;

	/**
	 * Constructor.
	 * <p>
	 * Creates a client backed by a new {@link BcTlsCrypto} that uses a fresh {@link SecureRandom}.
	 *
	 * @param host target host name (sent as SNI where permitted)
	 * @param port target port
	 * @param sock socket this client belongs to
	 */
	public V2CTLSClient(String host, int port, V2CTLSSocket sock) {
		super((TlsCrypto) new BcTlsCrypto(new SecureRandom()));
		this.host = host;
		this.port = port;
		this.tlsSocket = sock;
		// NOTE(review): 'this' escapes before construction has finished; V2CTLSAuthentication
		// must not call back into this client from its own constructor.
		this.tlsAuthentication = new V2CTLSAuthentication(this);
	}

	/**
	 * Returns the socket this client was created for.
	 *
	 * @return the associated socket
	 */
	public V2CTLSSocket getTlsSocket() {
		return tlsSocket;
	}

	/**
	 * Gets the target host.
	 *
	 * @return host
	 */
	public String getHost() {
		return host;
	}

	/**
	 * Sets the target host.
	 *
	 * @param host new host name
	 */
	public void setHost(String host) {
		this.host = host;
	}

	/**
	 * Gets the target port.
	 *
	 * @return port
	 */
	public int getPort() {
		return port;
	}

	/**
	 * Sets the target port.
	 *
	 * @param port new port
	 */
	public void setPort(int port) {
		this.port = port;
	}

	/**
	 * Gets the session by delegating to {@code context.getResumableSession()}.
	 * <p>
	 * The inherited {@code context} field must already be set. Presumably the protocol sets it
	 * during the handshake; verify this before calling the method earlier.
	 *
	 * @return session
	 */
	public TlsSession getSession() {
		return context.getResumableSession();
	}

	/**
	 * Reports whether the negotiated version is TLS 1.2, per {@link TlsUtils#isTLSv12}.
	 *
	 * @return {@code true} if TLS 1.2 was negotiated
	 */
	public boolean isTLSv12() {
		return TlsUtils.isTLSv12(context);
	}

	/**
	 * Reports whether the negotiated version is TLS 1.3, per {@link TlsUtils#isTLSv13}.
	 *
	 * @return {@code true} if TLS 1.3 was negotiated
	 */
	public boolean isTLSv13() {
		return TlsUtils.isTLSv13(context);
	}

	/**
	 * Returns the cipher suite recorded in the context's security parameters.
	 *
	 * @return the cipher suite code
	 */
	public int getSelectedCipherSuite() {
		return context.getSecurityParameters().getCipherSuite();
	}

	/**
	 * Returns the server version recorded in the context.
	 *
	 * @return the server's protocol version
	 */
	public ProtocolVersion getProtocol() {
		return context.getServerVersion();
	}

	/**
	 * Returns the parent's client extensions, plus a {@code server_name} (SNI) entry for
	 * {@link #host}.
	 * <p>
	 * SNI is omitted when the host is null or empty, or is an IP address literal. RFC 6066
	 * forbids address literals, and an empty HostName is malformed. Non-ASCII host names are
	 * converted to A-labels, and a trailing dot is removed, as RFC 6066 requires.
	 *
	 * @return the extension table, never {@code null}
	 * @throws IOException if the parent fails or encoding fails
	 */
	@Override
	public Hashtable<Integer, byte[]> getClientExtensions() throws IOException {

		Hashtable<Integer, byte[]> clientExtensions = super.getClientExtensions();
		if (clientExtensions == null) {
			clientExtensions = new Hashtable<Integer, byte[]>();
		}

		String sniName = toSniHostName(host);
		if (sniName != null) {
			clientExtensions.put(ExtensionType.server_name, encodeServerNameList(sniName));
		}
		return clientExtensions;
	}

	/**
	 * Normalises {@code name} for use as an SNI host name.
	 *
	 * @param name raw host name, may be {@code null}
	 * @return the host name to send, or {@code null} if no SNI should be sent
	 */
	private static String toSniHostName(String name) {
		if (name == null) {
			return null;
		}
		String h = name;
		// RFC 6066: HostName is sent without a trailing dot.
		if (h.endsWith(".")) {
			h = h.substring(0, h.length() - 1);
		}
		if (h.length() == 0) {
			return null;
		}
		// RFC 6066: literal IPv4 and IPv6 addresses are not permitted in HostName.
		if (h.indexOf(':') >= 0 || IPV4_LITERAL.matcher(h).matches()) {
			return null;
		}
		try {
			// Internationalised names must be sent as ASCII A-labels (punycode).
			return IDN.toASCII(h);
		} catch (IllegalArgumentException e) {
			// Not convertible (e.g. an over-long label): send it unchanged, as before.
			return h;
		}
	}

	/**
	 * Encodes a ServerNameList containing a single host_name entry (RFC 6066, section 3).
	 *
	 * @param hostName non-empty host name
	 * @return the extension_data bytes
	 * @throws IOException if writing to the in-memory stream fails
	 */
	private static byte[] encodeServerNameList(String hostName) throws IOException {
		byte[] name = hostName.getBytes("UTF-8"); // identical to ASCII for A-label names
		ByteArrayOutputStream baos = new ByteArrayOutputStream(name.length + 5);
		DataOutputStream dos = new DataOutputStream(baos);
		dos.writeShort(name.length + 3); // list size: name_type(1) + HostName length(2) + HostName
		dos.writeByte(0); // name type = host_name
		dos.writeShort(name.length);
		dos.write(name);
		dos.close();
		return baos.toByteArray();
	}

	/**
	 * Accepts an unsolicited ec_point_formats extension, provided it parses as a valid
	 * ECPointFormatList. Every other extension type is decided by the superclass.
	 */
	@Override
	protected boolean allowUnexpectedServerExtension(Integer extensionType, byte[] extensionData) throws IOException {

		switch (extensionType.intValue()) {
		case ExtensionType.ec_point_formats:
			/*
			 * Exception added based on field reports that some servers send Supported
			 * Point Format Extension even when not negotiating an ECC cipher suite.
			 * If present, we still require that it is a valid ECPointFormatList.
			 */
			TlsExtensionsUtils.readSupportedPointFormatsExtension(extensionData);
			return true;
		default:
			return super.allowUnexpectedServerExtension(extensionType, extensionData);
		}
	}

	/**
	 * Returns the {@link V2CTLSAuthentication} created in the constructor.
	 *
	 * @return the authentication object used to verify the server
	 * @throws IOException never thrown here; declared by the interface
	 */
	@Override
	public TlsAuthentication getAuthentication() throws IOException {

		return this.tlsAuthentication;
	}

}
