381 lines
12 KiB
Java
381 lines
12 KiB
Java
package org.gcube.data.analysis.tabulardata.cube.data;
|
|
|
|
import java.io.BufferedReader;
|
|
import java.io.IOException;
|
|
import java.io.InputStream;
|
|
import java.io.InputStreamReader;
|
|
import java.sql.Connection;
|
|
import java.sql.ResultSet;
|
|
import java.sql.SQLException;
|
|
import java.sql.Statement;
|
|
import java.util.List;
|
|
import java.util.Random;
|
|
import java.util.regex.Pattern;
|
|
|
|
import javax.annotation.PostConstruct;
|
|
import javax.enterprise.inject.Default;
|
|
import javax.inject.Inject;
|
|
import javax.inject.Singleton;
|
|
|
|
import org.apache.commons.dbutils.DbUtils;
|
|
import org.gcube.data.analysis.tabulardata.cube.data.connection.DatabaseConnectionProvider;
|
|
import org.gcube.data.analysis.tabulardata.cube.data.connection.admin.Admin;
|
|
import org.gcube.data.analysis.tabulardata.cube.data.connection.unprivileged.Unprivileged;
|
|
import org.gcube.data.analysis.tabulardata.model.datatype.DataType;
|
|
import org.gcube.data.analysis.tabulardata.model.datatype.value.TDTypeValue;
|
|
import org.gcube.data.analysis.tabulardata.model.mapping.SQLModelMapper;
|
|
import org.slf4j.Logger;
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
@Default
|
|
@Singleton
|
|
public class SQLDatabaseWrangler implements DatabaseWrangler {
|
|
|
|
private final String DEFAULT_SCHEMA_NAME = "public";
|
|
|
|
private static Logger log = LoggerFactory.getLogger(SQLDatabaseWrangler.class);
|
|
|
|
private static RandomString randomString = new RandomString(32);
|
|
|
|
private DatabaseConnectionProvider adminConnectionProvider;
|
|
|
|
private DatabaseConnectionProvider unprivilegedConnectionProvider;
|
|
|
|
private SQLModelMapper sqlModelMapper;
|
|
|
|
private ResourceFinder resourceFinder;
|
|
|
|
@Inject
|
|
public SQLDatabaseWrangler(@Admin DatabaseConnectionProvider adminConnectionProvider,
|
|
@Unprivileged DatabaseConnectionProvider unprivilegedConnectionProvider, SQLModelMapper sqlModelMapper, @Default ResourceFinder resourceFinder) {
|
|
super();
|
|
this.adminConnectionProvider = adminConnectionProvider;
|
|
this.unprivilegedConnectionProvider = unprivilegedConnectionProvider;
|
|
this.sqlModelMapper = sqlModelMapper;
|
|
this.resourceFinder = resourceFinder;
|
|
}
|
|
|
|
@PostConstruct
|
|
private void initializeSql() {
|
|
for (String file :resourceFinder.getResourcesPath(Pattern.compile(".*\\.sql"))){
|
|
BufferedReader reader =null;
|
|
if (!file.contains("org/gcube/data/analysis/tabulardata/sql/"))
|
|
file = "org/gcube/data/analysis/tabulardata/sql/"+file;
|
|
|
|
InputStream is = resourceFinder.getStream(file);
|
|
if (is==null)
|
|
continue;
|
|
|
|
try{
|
|
reader = new BufferedReader(new InputStreamReader(is));
|
|
String line = null;
|
|
StringBuilder stringBuilder = new StringBuilder();
|
|
while( ( line = reader.readLine() ) != null )
|
|
stringBuilder.append( line );
|
|
|
|
executeQuery(stringBuilder.toString());
|
|
}catch(Exception e){
|
|
throw new RuntimeException("error initializing sql",e);
|
|
}finally{
|
|
if (reader!=null )
|
|
try {
|
|
reader.close();
|
|
} catch (IOException e) {
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@Override
|
|
public String createTable() {
|
|
return createTable(false);
|
|
}
|
|
|
|
@Override
|
|
public void createTable(String name){
|
|
createInternal(false, name);
|
|
}
|
|
|
|
@Override
|
|
public String createTable(boolean unsafe) {
|
|
String tableName = generateTableName();
|
|
createInternal(unsafe, tableName);
|
|
return tableName;
|
|
}
|
|
|
|
private void createInternal(boolean unsafe, String tableName){
|
|
String query = generateCreateTableQuery(tableName, unsafe);
|
|
query += generateUserAccountGrantQuery(tableName);
|
|
query += generateUserAccountGrantQuery(tableName + "_id_seq");
|
|
executeQuery(query);
|
|
}
|
|
|
|
private String generateCreateTableQuery(String tableName, boolean unlogged) {
|
|
if (unlogged)
|
|
return String.format("CREATE UNLOGGED TABLE %1$s ( id serial primary key);", tableName);
|
|
return String.format("CREATE TABLE %1$s ( id serial primary key);", tableName);
|
|
}
|
|
|
|
@Override
|
|
public void removeTable(String tableName) {
|
|
String query = generateDropTableQuery(tableName);
|
|
executeQuery(query);
|
|
}
|
|
|
|
private String generateDropTableQuery(String tableName) {
|
|
return String.format("DROP TABLE %1$s;", tableName);
|
|
}
|
|
|
|
@Override
|
|
public String cloneTable(String tableName, boolean withData, boolean unsafe) {
|
|
String newTableName = generateTableName();
|
|
String query = generateCloneTableQuery(newTableName, tableName, withData, unsafe);
|
|
query += generateUserAccountGrantQuery(newTableName);
|
|
query += generateUserAccountGrantQuery(newTableName + "_id_seq");
|
|
executeQuery(query);
|
|
return newTableName;
|
|
}
|
|
|
|
private String generateCloneTableQuery(String newTableName, String tableToCloneName, boolean withData,
|
|
boolean unsafe) {
|
|
StringBuilder sb = new StringBuilder();
|
|
String unlogged = "";
|
|
if (unsafe)
|
|
unlogged = "UNLOGGED";
|
|
String data = "";
|
|
|
|
if (withData)
|
|
data = "WITH DATA";
|
|
else
|
|
data = "WITH NO DATA";
|
|
|
|
sb.append(String.format("CREATE %1$s TABLE %2$s WITHOUT OIDS AS TABLE %3$s %4$s;", unlogged, newTableName, tableToCloneName,
|
|
data));
|
|
sb.append(String.format("CREATE SEQUENCE %1$s_id_seq;", newTableName));
|
|
if (withData)
|
|
sb.append(String.format("SELECT setval('%1$s_id_seq', max(id) ) FROM %2$s;", newTableName, tableToCloneName));
|
|
else
|
|
sb.append(String.format("SELECT setval('%1$s_id_seq', 1 );", newTableName));
|
|
sb.append(String.format("ALTER TABLE %1$s ALTER id SET NOT NULL;", newTableName));
|
|
sb.append(String.format("ALTER TABLE %1$s ALTER id SET DEFAULT nextval('%1$s_id_seq');", newTableName));
|
|
log.debug("executing creation queries: "+sb.toString());
|
|
return sb.toString();
|
|
}
|
|
|
|
@Override
|
|
public boolean exists(String tableName) {
|
|
return executeCount(String.format("SELECT count(*) FROM pg_tables WHERE tablename='%1$s'", tableName.toLowerCase()))>0;
|
|
|
|
}
|
|
|
|
@Override
|
|
public void addColumn(String tableName, String columnName, DataType type) {
|
|
String query = generateAddColumnQuery(tableName, columnName, type, null);
|
|
executeQuery(query);
|
|
}
|
|
|
|
@Override
|
|
public void addColumn(String tableName, String columnName, DataType type, TDTypeValue defaultValue) {
|
|
String query = generateAddColumnQuery(tableName, columnName, type, defaultValue);
|
|
executeQuery(query);
|
|
}
|
|
|
|
private String generateAddColumnQuery(String tableName, String columnName, DataType type,TDTypeValue defaultValue) {
|
|
return String.format("ALTER TABLE %1$s ADD COLUMN %2$s %3$s %4$s", tableName, columnName, getColumnSQLType(type), getDefaultValueSQL(defaultValue));
|
|
}
|
|
|
|
private String getColumnSQLType(DataType type) {
|
|
return sqlModelMapper.translateDataTypeToSQL(type);
|
|
}
|
|
|
|
private String getDefaultValueSQL(TDTypeValue defaultValue) {
|
|
if (defaultValue==null) return "";
|
|
else return String.format("DEFAULT %s", sqlModelMapper.translateModelValueToSQL(defaultValue));
|
|
}
|
|
|
|
@Override
|
|
public void removeColumn(String tableName, String columnName) {
|
|
String query = generateDropColumnQuery(tableName, columnName);
|
|
executeQuery(query);
|
|
}
|
|
|
|
private String generateDropColumnQuery(String tableName, String columnName) {
|
|
return String.format("ALTER TABLE %1$s DROP COLUMN %2$s;", tableName, columnName);
|
|
}
|
|
|
|
@Override
|
|
public void alterColumnType(String tableName, String columnName,
|
|
DataType newType) {
|
|
String query = generateAlterTypeQuery(tableName, columnName, newType);
|
|
executeQuery(query);
|
|
}
|
|
|
|
private String generateAlterTypeQuery(String tableName, String columnName, DataType type) {
|
|
return String.format("ALTER TABLE %1$s ALTER COLUMN %2$s TYPE %3$s;", tableName, columnName, getColumnSQLType(type));
|
|
}
|
|
|
|
private String generateTableName() {
|
|
String tableName = null;
|
|
int count = 0;
|
|
|
|
do {
|
|
tableName = randomString.nextString().toLowerCase();
|
|
log.debug("Generated table name: " + tableName);
|
|
Connection connection = null;
|
|
Statement statement = null;
|
|
try {
|
|
connection = adminConnectionProvider.getConnection();
|
|
statement = connection.createStatement();
|
|
statement.execute(String.format(
|
|
"SELECT * FROM pg_tables WHERE schemaname='%1$s' AND tablename='%2$s';", DEFAULT_SCHEMA_NAME,
|
|
tableName));
|
|
count = statement.getFetchSize();
|
|
log.debug(String.format("Table with name '%1$s' found %2$s times.", tableName, count));
|
|
} catch (SQLException e) {
|
|
log.error("Error occurred while verifying generated table name.", e);
|
|
throw new RuntimeException("Unable to generate a table name", e);
|
|
} finally {
|
|
DbUtils.closeQuietly(connection);
|
|
DbUtils.closeQuietly(statement);
|
|
}
|
|
|
|
} while (count > 0);
|
|
|
|
return tableName;
|
|
}
|
|
|
|
public void executeQuery(String query) {
|
|
log.debug("Executing SQL query: " + query);
|
|
Connection connection = null;
|
|
Statement statement = null;
|
|
try {
|
|
connection = adminConnectionProvider.getConnection();
|
|
statement = connection.createStatement();
|
|
statement.execute(query + ";");
|
|
connection.close();
|
|
} catch (SQLException e) {
|
|
log.error("Unable to execute query: " + query, e);
|
|
throw new RuntimeException("Error encountered while executing database query: " + query,e);
|
|
} finally {
|
|
DbUtils.closeQuietly(connection);
|
|
DbUtils.closeQuietly(statement);
|
|
}
|
|
}
|
|
|
|
private int executeCount(String query) {
|
|
log.debug("Executing SQL query: " + query);
|
|
Connection connection = null;
|
|
Statement statement = null;
|
|
try {
|
|
connection = adminConnectionProvider.getConnection();
|
|
statement = connection.createStatement();
|
|
ResultSet ret = statement.executeQuery(query + ";");
|
|
int toReturn = 0;
|
|
if (ret.next())
|
|
toReturn = ret.getInt(1);
|
|
connection.close();
|
|
return toReturn;
|
|
} catch (SQLException e) {
|
|
log.error("Unable to execute query: " + query, e);
|
|
throw new RuntimeException("Error encountered while executing database query: " + query,e);
|
|
} finally {
|
|
DbUtils.closeQuietly(connection);
|
|
DbUtils.closeQuietly(statement);
|
|
}
|
|
}
|
|
|
|
private String generateUserAccountGrantQuery(String tableName) {
|
|
String unprivilegedUser = unprivilegedConnectionProvider.getDatabaseEndpoint().getCredentials().getUsername();
|
|
return String.format("GRANT SELECT,UPDATE,INSERT ON TABLE %1$s TO %2$s;", tableName, unprivilegedUser);
|
|
}
|
|
|
|
@Override
|
|
public void createIndex(String tableName, String columnName) {
|
|
Thread t = new Thread(new IndexCreator(tableName, columnName));
|
|
t.start();
|
|
}
|
|
|
|
private class IndexCreator implements Runnable {
|
|
|
|
private String tableName;
|
|
private String columnName;
|
|
|
|
public IndexCreator(String tableName, String columnName) {
|
|
this.tableName = tableName;
|
|
this.columnName = columnName;
|
|
}
|
|
|
|
@Override
|
|
public void run() {
|
|
String query = String.format("CREATE INDEX ON %1$s ( %2$s );", tableName, columnName);
|
|
executeQuery(query);
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public void setNullable(String tableName, String columnName, boolean nullable) {
|
|
String notNullSnippet = "SET NOT NULL";
|
|
String nullableSnippet = "DROP NOT NULL";
|
|
executeQuery(String.format("ALTER TABLE %s ALTER COLUMN %s %s;", tableName, columnName,
|
|
nullable ? nullableSnippet : notNullSnippet));
|
|
}
|
|
|
|
@Override
|
|
public void createTriggerOnTable(String triggerName, List<Condition> conditions, HTime htime, String targetTableName, String procedure) {
|
|
|
|
if (conditions.isEmpty()) throw new IllegalArgumentException("at least a condition has to be set");
|
|
StringBuilder sBuilder = new StringBuilder();
|
|
for (Condition cond : conditions)
|
|
sBuilder.append(cond.name()).append(" OR ");
|
|
|
|
String conds = sBuilder.delete(sBuilder.length()-4, sBuilder.length()).toString();
|
|
|
|
executeQuery(String.format(
|
|
"CREATE TRIGGER %s %s %s ON %s FOR EACH ROW EXECUTE PROCEDURE %s;", triggerName, htime.name(), conds,
|
|
targetTableName, procedure));
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
public void createUniqueIndex(String tableName, String columnName) {
|
|
executeQuery(String.format("CREATE UNIQUE INDEX ON %s ( %s );", tableName, columnName));
|
|
}
|
|
|
|
public static class RandomString {
|
|
|
|
private static final char[] symbols;
|
|
private static final char[] startingSymbols;
|
|
|
|
static {
|
|
StringBuilder tmp = new StringBuilder();
|
|
for (char ch = 'a'; ch <= 'z'; ++ch)
|
|
tmp.append(ch);
|
|
startingSymbols = tmp.toString().toCharArray();
|
|
for (char ch = '0'; ch <= '9'; ++ch)
|
|
tmp.append(ch);
|
|
symbols = tmp.toString().toCharArray();
|
|
}
|
|
|
|
private final int length;
|
|
|
|
private final Random random = new Random();
|
|
|
|
public RandomString(int length) {
|
|
if (length < 1)
|
|
throw new IllegalArgumentException("length < 1: " + length);
|
|
this.length = length;
|
|
}
|
|
|
|
public synchronized String nextString() {
|
|
char[] buf = new char[this.length];
|
|
buf[0] = startingSymbols[random.nextInt(startingSymbols.length)];
|
|
for (int idx = 1; idx < buf.length; ++idx)
|
|
buf[idx] = symbols[random.nextInt(symbols.length)];
|
|
return new String(buf);
|
|
}
|
|
}
|
|
}
|