import psycopg2
import pymysql
import json
import logging
import os
from enum import IntEnum


class DataType(IntEnum):
    VALUE = 0
    TIME = 1
    CHAR = 2


AGGREGATE_CONSTRAINTS = {
    DataType.VALUE.value: ["count", "max", "min", "avg", "sum"],
    DataType.CHAR.value: ["count", "max", "min"],
    DataType.TIME.value: ["count", "max", "min"],
}
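
# Example lookup (illustrative only, not part of the original code):
#     AGGREGATE_CONSTRAINTS[DataType.CHAR.value]  # -> ["count", "max", "min"]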


def transfer_field_type(database_type, server):
    data_type = list()
    if server == "mysql":
        data_type = [
            [
                "int",
                "tinyint",
                "smallint",
                "mediumint",
                "bigint",
                "float",
                "double",
                "decimal",
            ],
            ["date", "time", "year", "datetime", "timestamp"],
        ]
        database_type = database_type.lower().split("(")[0]
    elif server == "postgresql":
        data_type = [["integer", "numeric"], ["date"]]
    if database_type in data_type[0]:
        return DataType.VALUE.value
    elif database_type in data_type[1]:
        return DataType.TIME.value
    else:
        return DataType.CHAR.value
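
# Illustrative mapping (added for clarity, not part of the original code):
#     transfer_field_type("decimal(10,2)", "mysql")  -> DataType.VALUE.value
#     transfer_field_type("varchar(255)", "mysql")   -> DataType.CHAR.value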


class DBArgs(object):
    def __init__(self, dbtype, config, dbname=None):
        self.dbtype = dbtype
        if self.dbtype == "mysql":
            self.host = config["host"]
            self.port = config["port"]
            self.user = config["user"]
            self.password = config["password"]
            self.dbname = dbname if dbname else config["dbname"]
            self.driver = "com.mysql.jdbc.Driver"
            self.jdbc = "jdbc:mysql://"
        else:
            self.host = config["host"]
            self.port = config["port"]
            self.user = config["user"]
            self.password = config["password"]
            self.dbname = dbname if dbname else config["dbname"]
            self.driver = "org.postgresql.Driver"
            self.jdbc = "jdbc:postgresql://"


class Database:
    def __init__(self, args, timeout=-1):
        self.args = args
        self.conn = self.resetConn(timeout)
        # self.schema = self.compute_table_schema()

    def resetConn(self, timeout=-1):
        if self.args.dbtype == "mysql":
            # pymysql rejects non-positive timeouts, so only pass them
            # when a real timeout is requested
            timeouts = (
                dict(
                    connect_timeout=timeout,
                    read_timeout=timeout,
                    write_timeout=timeout,
                )
                if timeout > 0
                else {}
            )
            conn = pymysql.connect(
                host=self.args.host,
                user=self.args.user,
                passwd=self.args.password,
                database=self.args.dbname,
                port=int(self.args.port),
                charset="utf8",
                **timeouts,
            )
        else:
            if timeout > 0:
                conn = psycopg2.connect(
                    database=self.args.dbname,
                    user=self.args.user,
                    password=self.args.password,
                    host=self.args.host,
                    port=self.args.port,
                    options="-c statement_timeout={}s".format(timeout),
                )
            else:
                conn = psycopg2.connect(
                    database=self.args.dbname,
                    user=self.args.user,
                    password=self.args.password,
                    host=self.args.host,
                    port=self.args.port,
                )
        return conn
| """ | |
| def exec_fetch(self, statement, one=True): | |
| cur = self.conn.cursor() | |
| cur.execute(statement) | |
| if one: | |
| return cur.fetchone() | |
| return cur.fetchall() | |
| """ | |

    def execute_sql(self, sql):
        fail = 1
        self.conn = self.resetConn()
        cur = self.conn.cursor()
        i = 0
        cnt = 3  # retry times
        while fail == 1 and i < cnt:
            try:
                fail = 0
                cur.execute(sql)
            except BaseException:
                fail = 1
                res = []
            if fail == 0:
                res = cur.fetchall()
            i = i + 1
        logging.debug(
            "database {}, return flag {}, execute sql {}\n".format(
                self.args.dbname, 1 - fail, sql
            )
        )
        if fail == 1:
            # raise RuntimeError("Database query failed")
            print("SQL Execution Fatal!!")
            return 0, ""
        elif fail == 0:
            # print("SQL Execution Succeed!!")
            return 1, res
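
    # Calling convention (illustrative only): execute_sql returns a
    # (flag, rows) pair, flag 1 on success and 0 after all retries fail, e.g.
    #     success, rows = db.execute_sql("SELECT 1;")
    #     if success == 1:
    #         print(rows)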

    def pgsql_results(self, sql):
        try:
            # success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
            success, res = self.execute_sql(sql)
            # print("pgsql_results", success, res)
            if success == 1:
                return res
            else:
                return "<fail>"
        except Exception as error:
            logging.error("pgsql_results Exception: %s", error)
            return "<fail>"

    def pgsql_cost_estimation(self, sql):
        try:
            # success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
            success, res = self.execute_sql("explain (FORMAT JSON) " + sql)
            if success == 1:
                cost = res[0][0][0]["Plan"]["Total Cost"]
                return cost
            else:
                logging.error("pgsql_cost_estimation Fails!")
                return 0
        except Exception as error:
            logging.error("pgsql_cost_estimation Exception: %s", error)
            return 0

    def pgsql_actual_time(self, sql):
        try:
            # success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
            success, res = self.execute_sql("explain (FORMAT JSON, analyze) " + sql)
            if success == 1:
                cost = res[0][0][0]["Plan"]["Actual Total Time"]
                return cost
            else:
                return -1
        except Exception as error:
            logging.error("pgsql_actual_time Exception: %s", error)
            return -1

    def mysql_cost_estimation(self, sql):
        try:
            success, res = self.execute_sql("explain format=json " + sql)
            if success == 1:
                total_cost = self.get_mysql_total_cost(0, json.loads(res[0][0]))
                return float(total_cost)
            else:
                return -1
        except Exception as error:
            logging.error("mysql_cost_estimation Exception: %s", error)
            return -1

    def get_mysql_total_cost(self, total_cost, res):
        if isinstance(res, dict):
            if "query_cost" in res.keys():
                total_cost += float(res["query_cost"])
            else:
                for key in res:
                    total_cost += self.get_mysql_total_cost(0, res[key])
        elif isinstance(res, list):
            for i in res:
                total_cost += self.get_mysql_total_cost(0, i)
        return total_cost
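
    # For illustration only: given a trimmed MySQL EXPLAIN FORMAT=JSON document
    #     {"query_block": {"cost_info": {"query_cost": "2.75"}}}
    # get_mysql_total_cost(0, doc) returns 2.75, since every nested
    # "query_cost" entry is summed recursively.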

    def get_tables(self):
        if self.args.dbtype == "mysql":
            return self.mysql_get_tables()
        else:
            return self.pgsql_get_tables()

    # query cost estimated by the optimizer
    def cost_estimation(self, sql):
        if self.args.dbtype == "mysql":
            return self.mysql_cost_estimation(sql)
        else:
            return self.pgsql_cost_estimation(sql)

    def compute_table_schema(self):
        """
        schema: {table_name: [field_name]}
        :return: a dict mapping each table name to its column names
        """
        if self.args.dbtype == "postgresql":
            # cur_path = os.path.abspath('.')
            # tpath = cur_path + '/sampled_data/'+dbname+'/schema'
            sql = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
            success, res = self.execute_sql(sql)
            # print("======== tables", res)
            if success == 1:
                tables = res
                schema = {}
                for table_info in tables:
                    table_name = table_info[0]
                    sql = (
                        "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '"
                        + table_name
                        + "';"
                    )
                    success, res = self.execute_sql(sql)
                    # print("======== table columns", res)
                    columns = res
                    schema[table_name] = []
                    for col in columns:
                        """compute the distinct value ratio of the column
                        if transfer_field_type(col[1], self.args.dbtype) == DataType.VALUE.value:
                            sql = 'SELECT count({}) FROM {};'.format(col[0], table_name)
                            success, res = self.execute_sql(sql)
                            print("======== column rows", res)
                            num = res
                            if num[0][0] != 0:
                                schema[table_name].append(col[0])
                        """
                        # schema[table_name].append("column {} is of {} type".format(col[0], col[1]))
                        schema[table_name].append("{}".format(col[0]))
                """
                with open(tpath, 'w') as f:
                    f.write(str(schema))
                """
                # print(schema)
                return schema
            else:
                logging.error("compute_table_schema Fails!")
                return 0

    def simulate_index(self, index):
        # relies on the HypoPG extension to create a hypothetical index
        # table_name = index.table()
        statement = "SELECT * FROM hypopg_create_index(E'{}');".format(index)
        result = self.execute_sql(statement)
        return result

    def drop_simulated_index(self, oid):
        statement = f"select * from hypopg_drop_index({oid})"
        result = self.execute_sql(statement)
        # execute_sql returns a (flag, rows) pair with flag 1 on success
        assert result[0] == 1, f"Could not drop simulated index with oid = {oid}."
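

# A minimal usage sketch, assuming a reachable PostgreSQL instance (and, for
# simulate_index(), the HypoPG extension installed). Connection values are
# placeholders, not real settings.
if __name__ == "__main__":
    example_config = {
        "host": "127.0.0.1",
        "port": 5432,
        "user": "postgres",
        "password": "postgres",
        "dbname": "postgres",
    }
    db = Database(DBArgs("postgresql", example_config))
    print(db.cost_estimation("SELECT 1;"))
    print(db.compute_table_schema())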