package org.jruby.puma;

import org.jruby.Ruby; import org.jruby.RubyClass; import org.jruby.RubyModule; import org.jruby.RubyObject; import org.jruby.RubyString; import org.jruby.anno.JRubyMethod; import org.jruby.runtime.Block; import org.jruby.runtime.ObjectAllocator; import org.jruby.runtime.ThreadContext; import org.jruby.runtime.builtin.IRubyObject; import org.jruby.util.ByteList;

import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.security.KeyManagementException; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.UnrecoverableKeyException; import java.security.cert.CertificateException;

import static javax.net.ssl.SSLEngineResult.Status; import static javax.net.ssl.SSLEngineResult.HandshakeStatus;

public class MiniSSL extends RubyObject {

private static ObjectAllocator ALLOCATOR = new ObjectAllocator() {
  public IRubyObject allocate(Ruby runtime, RubyClass klass) {
    return new MiniSSL(runtime, klass);
  }
};

public static void createMiniSSL(Ruby runtime) {
  RubyModule mPuma = runtime.defineModule("Puma");
  RubyModule ssl = mPuma.defineModuleUnder("MiniSSL");

  mPuma.defineClassUnder("SSLError",
                         runtime.getClass("IOError"),
                         runtime.getClass("IOError").getAllocator());

  RubyClass eng = ssl.defineClassUnder("Engine",runtime.getObject(),ALLOCATOR);
  eng.defineAnnotatedMethods(MiniSSL.class);
}

   Fairly transparent wrapper around {@link java.nio.ByteBuffer} which adds the enhancements we need

private static class MiniSSLBuffer {
  ByteBuffer buffer;

  private MiniSSLBuffer(int capacity) { buffer = ByteBuffer.allocate(capacity); }
  private MiniSSLBuffer(byte[] initialContents) { buffer = ByteBuffer.wrap(initialContents); }

  public void clear() { buffer.clear(); }
  public void compact() { buffer.compact(); }
  public void flip() { buffer.flip(); }
  public boolean hasRemaining() { return buffer.hasRemaining(); }
  public int position() { return buffer.position(); }

  public ByteBuffer getRawBuffer() {
    return buffer;
  }

  /**
     Writes bytes to the buffer after ensuring there's room
    /
  public void put(byte[] bytes) {
    if (buffer.remaining() < bytes.length) {
      resize(buffer.limit() + bytes.length);
    }
    buffer.put(bytes);
  }

  /**
     Ensures that newCapacity bytes can be written to this buffer, only re-allocating if necessary
    /
  public void resize(int newCapacity) {
    if (newCapacity > buffer.capacity()) {
      ByteBuffer dstTmp = ByteBuffer.allocate(newCapacity);
      buffer.flip();
      dstTmp.put(buffer);
      buffer = dstTmp;
    } else {
      buffer.limit(newCapacity);
    }
  }

  /**
     Drains the buffer to a ByteList, or returns null for an empty buffer
    /
  public ByteList asByteList() {
    buffer.flip();
    if (!buffer.hasRemaining()) {
      buffer.clear();
      return null;
    }

    byte[] bss = new byte[buffer.limit()];

    buffer.get(bss);
    buffer.clear();
    return new ByteList(bss);
  }

  @Override
  public String toString() { return buffer.toString(); }
}

private SSLEngine engine;
private MiniSSLBuffer inboundNetData;
private MiniSSLBuffer outboundAppData;
private MiniSSLBuffer outboundNetData;

public MiniSSL(Ruby runtime, RubyClass klass) {
  super(runtime, klass);
}

@JRubyMethod(meta = true)
public static IRubyObject server(ThreadContext context, IRubyObject recv, IRubyObject miniSSLContext) {
  RubyClass klass = (RubyClass) recv;

  return klass.newInstance(context,
      new IRubyObject[] { miniSSLContext },
      Block.NULL_BLOCK);
}

@JRubyMethod
public IRubyObject initialize(ThreadContext threadContext, IRubyObject miniSSLContext)
    throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException, UnrecoverableKeyException, KeyManagementException {
  KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
  KeyStore ts = KeyStore.getInstance(KeyStore.getDefaultType());

  char[] password = miniSSLContext.callMethod(threadContext, "keystore_pass").convertToString().asJavaString().toCharArray();
  String keystoreFile = miniSSLContext.callMethod(threadContext, "keystore").convertToString().asJavaString();
  ks.load(new FileInputStream(keystoreFile), password);
  ts.load(new FileInputStream(keystoreFile), password);

  KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
  kmf.init(ks, password);

  TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
  tmf.init(ts);

  SSLContext sslCtx = SSLContext.getInstance("TLS");

  sslCtx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
  engine = sslCtx.createSSLEngine();

  String[] protocols = new String[] { "TLSv1", "TLSv1.1", "TLSv1.2" };
  engine.setEnabledProtocols(protocols);
  engine.setUseClientMode(false);

  long verify_mode = miniSSLContext.callMethod(threadContext, "verify_mode").convertToInteger().getLongValue();
  if ((verify_mode & 0x1) != 0) { // 'peer'
      engine.setWantClientAuth(true);
  }
  if ((verify_mode & 0x2) != 0) { // 'force_peer'
      engine.setNeedClientAuth(true);
  }

  SSLSession session = engine.getSession();
  inboundNetData = new MiniSSLBuffer(session.getPacketBufferSize());
  outboundAppData = new MiniSSLBuffer(session.getApplicationBufferSize());
  outboundAppData.flip();
  outboundNetData = new MiniSSLBuffer(session.getPacketBufferSize());

  return this;
}

@JRubyMethod
public IRubyObject inject(IRubyObject arg) {
  try {
    byte[] bytes = arg.convertToString().getBytes();
    inboundNetData.put(bytes);
    return this;
  } catch (Exception e) {
    e.printStackTrace();
    throw new RuntimeException(e);
  }
}

private enum SSLOperation {
  WRAP,
  UNWRAP
}

private SSLEngineResult doOp(SSLOperation sslOp, MiniSSLBuffer src, MiniSSLBuffer dst) throws SSLException {
  SSLEngineResult res = null;
  boolean retryOp = true;
  while (retryOp) {
    switch (sslOp) {
      case WRAP:
        res = engine.wrap(src.getRawBuffer(), dst.getRawBuffer());
        break;
      case UNWRAP:
        res = engine.unwrap(src.getRawBuffer(), dst.getRawBuffer());
        break;
      default:
        throw new IllegalStateException("Unknown SSLOperation: " + sslOp);
    }

    switch (res.getStatus()) {
      case BUFFER_OVERFLOW:
        // increase the buffer size to accommodate the overflowing data
        int newSize = Math.max(engine.getSession().getPacketBufferSize(), engine.getSession().getApplicationBufferSize());
        dst.resize(newSize + dst.position());
        // retry the operation
        retryOp = true;
        break;
      case BUFFER_UNDERFLOW:
        // need to wait for more data to come in before we retry
        retryOp = false;
        break;
      default:
        // other cases are OK and CLOSED.  We're done here.
        retryOp = false;
    }
  }

  // after each op, run any delegated tasks if needed
  if(engine.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
    Runnable runnable;
    while ((runnable = engine.getDelegatedTask()) != null) {
      runnable.run();
    }
  }

  return res;
}

@JRubyMethod
public IRubyObject read() throws Exception {
  try {
    inboundNetData.flip();

    if(!inboundNetData.hasRemaining()) {
      return getRuntime().getNil();
    }

    MiniSSLBuffer inboundAppData = new MiniSSLBuffer(engine.getSession().getApplicationBufferSize());
    doOp(SSLOperation.UNWRAP, inboundNetData, inboundAppData);

    HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
    boolean done = false;
    while (!done) {
      switch (handshakeStatus) {
        case NEED_WRAP:
          doOp(SSLOperation.WRAP, inboundAppData, outboundNetData);
          break;
        case NEED_UNWRAP:
          SSLEngineResult res = doOp(SSLOperation.UNWRAP, inboundNetData, inboundAppData);
          if (res.getStatus() == Status.BUFFER_UNDERFLOW) {
            // need more data before we can shake more hands
            done = true;
          }
          break;
        default:
          done = true;
      }
      handshakeStatus = engine.getHandshakeStatus();
    }

    if (inboundNetData.hasRemaining()) {
      inboundNetData.compact();
    } else {
      inboundNetData.clear();
    }

    ByteList appDataByteList = inboundAppData.asByteList();
    if (appDataByteList == null) {
      return getRuntime().getNil();
    }

    RubyString str = getRuntime().newString("");
    str.setValue(appDataByteList);
    return str;
  } catch (Exception e) {
    throw getRuntime().newEOFError(e.getMessage());
  }
}

@JRubyMethod
public IRubyObject write(IRubyObject arg) {
  try {
    byte[] bls = arg.convertToString().getBytes();
    outboundAppData = new MiniSSLBuffer(bls);

    return getRuntime().newFixnum(bls.length);
  } catch (Exception e) {
    e.printStackTrace();
    throw new RuntimeException(e);
  }
}

@JRubyMethod
public IRubyObject extract() throws SSLException {
  try {
    ByteList dataByteList = outboundNetData.asByteList();
    if (dataByteList != null) {
      RubyString str = getRuntime().newString("");
      str.setValue(dataByteList);
      return str;
    }

    if (!outboundAppData.hasRemaining()) {
      return getRuntime().getNil();
    }

    outboundNetData.clear();
    doOp(SSLOperation.WRAP, outboundAppData, outboundNetData);
    dataByteList = outboundNetData.asByteList();
    if (dataByteList == null) {
      return getRuntime().getNil();
    }

    RubyString str = getRuntime().newString("");
    str.setValue(dataByteList);

    return str;
  } catch (Exception e) {
    e.printStackTrace();
    throw new RuntimeException(e);
  }
}

@JRubyMethod
public IRubyObject peercert() {
  return getRuntime().getNil();
}

}