package org.sqlite.udf;

import java.io.IOException;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.sqlite.Database;
import org.sqlite.jdbc.JdbcConnection;
import org.sqlite.swig.SQLite3Constants;
import org.sqlite.swig.SWIGTYPE_p_sqlite3;
import static org.junit.Assert.*;

/**
 * Tests for {@link ScalarFunction}.
 *
 * <p>They cover registering, invoking and dropping user-defined scalar SQL functions,
 * and reporting an error code from inside a function body.
 *
 * <p>Every test opens its own {@link Database} and releases each JDBC resource
 * ({@link ResultSet}, {@link Statement}, {@link Connection}) in its own {@code finally}
 * block. This way, a failed assertion or an expected {@link SQLException} does not leave
 * statements open when the database itself is closed.
 *
 * @author calico
 */
public class ScalarFunctionTest {

    private static final String DRIVER_CLASS = "org.sqlite.Driver";
    private static final String DATABASE = System.getProperty("user.dir") + "/test/unittest.db";

    public ScalarFunctionTest() {
    }

    @BeforeClass
    public static void setUpClass() throws Exception {
    }

    @AfterClass
    public static void tearDownClass() throws Exception {
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
    }

    /**
     * Loads the driver class and opens the unit-test database file.
     *
     * @return a newly opened database; the caller is responsible for closing it
     * @throws ClassNotFoundException if the driver class is not on the classpath
     * @throws SQLException if the database cannot be opened
     */
    private static Database newDatabase() throws ClassNotFoundException, SQLException {
        // NOTE(review): presumably loading the driver initializes state Database relies on
        // (e.g. the native library) - confirm.
        Class.forName(DRIVER_CLASS);
        return new Database(DATABASE, null);
    }

    /**
     * Wraps an already-open database in a JDBC connection.
     *
     * @param db the open database to wrap
     * @return a new JDBC connection; the caller is responsible for closing it
     */
    private static Connection newConnection(Database db) throws ClassNotFoundException, SQLException {
        return new JdbcConnection(db, null);
    }

    /**
     * Registers a {@code REGEXP} function and checks that the {@code REGEXP} operator uses it.
     * Then drops the function again.
     */
    @Test
    public void createFunction() throws ClassNotFoundException, SQLException, IOException {
        final Database db = newDatabase();
        try {
            // SQLite evaluates "X REGEXP Y" as regexp(Y, X), so argument 1 is the
            // pattern and argument 2 is the string being tested.
            final Function regexp = new ScalarFunction("REGEXP", 2) {
                @Override
                public void xFunc(Context ctx) throws SQLException {
                    System.out.printf("called xFunc('%s', '%s')!\n", ctx.getString(1), ctx.getString(2));

                    // Compile the pattern only when no compiled Pattern is cached as
                    // auxiliary data for argument 1 yet.
                    Pattern pattern = (Pattern) ctx.getAuxData(1);
                    if (pattern == null) {
                        System.out.println("pattern is null!");
                        try {
                            pattern = Pattern.compile(ctx.getString(1));
                            ctx.setAuxData(1, pattern);
                        } catch (PatternSyntaxException ex) {
                            // Report the bad pattern as an SQL error rather than throwing.
                            ctx.resultError(ex.toString());
                            return;
                        }
                    }

                    final SWIGTYPE_p_sqlite3 handle = ctx.getDbHandle();
                    assertNotNull(handle);

                    ctx.result(pattern.matcher(ctx.getString(2)).matches());
                }
            };
            assertFalse(regexp.isRegistered());
            db.createFunction(regexp);
            assertTrue(regexp.isRegistered());

            final Connection conn = newConnection(db);
            try {
                final Statement stmt = conn.createStatement();
                try {
                    final String sql
                            = "SELECT T.* "
                            + "FROM ("
                                    + "SELECT 'first' AS N UNION ALL "
                                    + "SELECT 'second' AS N UNION ALL "
                                    + "SELECT 'third' AS N UNION ALL "
                                    + "SELECT 'fourth' AS N UNION ALL "
                                    + "SELECT 'fifth' AS N UNION ALL "
                                    + "SELECT 'sixth' AS N"
                                + ") AS T "
                            + "WHERE T.N REGEXP 'f.*' "
                            + "ORDER BY T.N";
                    final ResultSet rs = stmt.executeQuery(sql);
                    try {
                        assertTrue(rs.next());
                        assertEquals("fifth", rs.getString(1));
                        assertTrue(rs.next());
                        assertEquals("first", rs.getString(1));
                        assertTrue(rs.next());
                        assertEquals("fourth", rs.getString(1));
                        assertFalse(rs.next());
                    } finally {
                        rs.close();
                    }

                    // The result set is already closed here. SQLite refuses to drop a
                    // function while statements using it are still active.
                    db.dropFunction(regexp);
                    assertFalse(regexp.isRegistered());
                } finally {
                    stmt.close();
                }
            } finally {
                conn.close();
            }
        } finally {
            db.close();
        }
    }

    /**
     * Checks that a dropped function can no longer be called.
     *
     * <p>The final query is expected to raise {@link SQLException}.
     */
    @Test(expected = java.sql.SQLException.class)
    public void dropFunction() throws ClassNotFoundException, SQLException, IOException {
        final Database db = newDatabase();
        try {
            final Function myFunc = new ScalarFunction("myFunc") {
                @Override
                public void xFunc(Context ctx) {
                    // Sets no result, so the SQL call evaluates to NULL.
                    System.out.println("called xFunc()!");
                }
            };
            assertFalse(myFunc.isRegistered());
            db.createFunction(myFunc);
            assertTrue(myFunc.isRegistered());

            final Connection conn = newConnection(db);
            try {
                final Statement stmt = conn.createStatement();
                try {
                    final ResultSet rs = stmt.executeQuery("SELECT myFunc('TEST') LIMIT 1");
                    try {
                        assertTrue(rs.next());
                        assertNull(rs.getString(1));
                    } finally {
                        rs.close();
                    }

                    db.dropFunction(myFunc);
                    assertFalse(myFunc.isRegistered());

                    // Expected to throw: myFunc is no longer registered. The statement and
                    // connection are still released by the finally blocks below.
                    stmt.executeQuery("SELECT myFunc('TEST') LIMIT 1").close();
                } finally {
                    stmt.close();
                }
            } finally {
                conn.close();
            }
        } finally {
            db.close();
        }
    }

    /**
     * Checks that an error code set by {@code Context.resultErrorCode} reaches the caller.
     *
     * <p>The code should surface as {@link SQLException#getErrorCode()}.
     */
    @Test(expected = java.sql.SQLException.class)
    public void resultErrorCode() throws ClassNotFoundException, SQLException, IOException {
        final Database db = newDatabase();
        final int errorCode = SQLite3Constants.SQLITE_EMPTY;
        try {
            final Function myFunc = new ScalarFunction("myFunc", 4) {
                @Override
                public void xFunc(Context ctx) throws SQLException {
                    ctx.resultErrorCode(errorCode);
                }
            };
            assertFalse(myFunc.isRegistered());
            db.createFunction(myFunc);
            assertTrue(myFunc.isRegistered());

            final Connection conn = newConnection(db);
            try {
                final Statement stmt = conn.createStatement();
                try {
                    final String sql
                            = "SELECT myFunc('http://.*', N, 'ftp://.*', N) FROM ("
                                + "SELECT 'first' AS N UNION ALL "
                                + "SELECT 'second' AS N UNION ALL "
                                + "SELECT 'third' AS N UNION ALL "
                                + "SELECT 'fourth' AS N UNION ALL "
                                + "SELECT 'fifth' AS N UNION ALL "
                                + "SELECT 'sixth' AS N"
                            + ")";
                    final ResultSet rs = stmt.executeQuery(sql);
                    try {
                        assertTrue(rs.next());
                    } finally {
                        rs.close();
                    }
                } finally {
                    stmt.close();
                }
            } finally {
                conn.close();
            }

        } catch (SQLException ex) {
            assertEquals(errorCode, ex.getErrorCode());
            throw ex;

        } finally {
            db.close();
        }
    }
}