diff --git a/stats.py b/stats.py index cb93504..d56d675 100644 --- a/stats.py +++ b/stats.py @@ -77,19 +77,20 @@ class StatsRunner(object): def update_user_ids(self, user_dict: Dict[int, Tuple[str, str]]): for uid in user_dict: username, display_name = user_dict[uid] - query = f""" + sql_dict = {'uid': uid, 'username': username, 'display_name': display_name} + query = """ UPDATE user_names - SET username = '{username}' - WHERE user_id = {uid} AND username IS DISTINCT FROM '{username}'; + SET username = %(username)s + WHERE user_id = %(uid)s AND username IS DISTINCT FROM %(username)s; """ if display_name: - query += f"""\n + query += """\n INSERT INTO user_names(user_id, date, username, display_name) - VALUES ({uid}, current_timestamp, '{username}', '{display_name}'); + VALUES (%(uid)s, current_timestamp, %(username)s, %(display_name)s); """ with self.engine.connect() as con: - con.execute(query) + con.execute(query, sql_dict) def get_chat_counts(self, n: int = None, start: str = None, end: str = None) -> Tuple[str, None]: """ @@ -100,6 +101,8 @@ class StatsRunner(object): :return: """ date_query = None + sql_dict = {} + query_conditions = [] if n is not None: if n <= 0: @@ -108,25 +111,26 @@ class StatsRunner(object): n = 20 if start: - start_dt = pd.to_datetime(start) - date_query = f"WHERE date >= '{start_dt}'" + sql_dict['start_dt'] = pd.to_datetime(start) + query_conditions.append("date >= %(start_dt)s") if end: - end_dt = pd.to_datetime(end) - if date_query: - date_query += f" AND date < '{end_dt}'" - else: - date_query = f"WHERE date < '{end_dt}'" + sql_dict['end_dt'] = pd.to_datetime(end) + query_conditions.append("date < %(end_dt)s") + + query_where = "" + if query_conditions: + query_where = f"WHERE {' AND '.join(query_conditions)}" query = f""" SELECT "from_user", COUNT(*) FROM "messages_utc" - {date_query} + {query_where} GROUP BY "from_user" ORDER BY "count" DESC; """ with self.engine.connect() as con: - df = pd.read_sql_query(query, con, index_col='from_user') + df = pd.read_sql_query(query, con, params=sql_dict, index_col='from_user') user_df = pd.Series(self.users, name="user") user_df = user_df.apply(lambda x: x[0]) # Take only @usernames @@ -146,17 +150,19 @@ class StatsRunner(object): :param end: End timestamp (e.g. 2019, 2019-01, 2019-01-01, "2019-01-01 14:21") """ query_conditions = [] + sql_dict = {} if start: - start_dt = pd.to_datetime(start) - query_conditions.append(f"date >= '{start_dt}'") + sql_dict['start_dt'] = pd.to_datetime(start) + query_conditions.append("date >= %(start_dt)s") if end: - end_dt = pd.to_datetime(end) - query_conditions.append(f"date < '{end_dt}'") + sql_dict['end_dt'] = pd.to_datetime(end) + query_conditions.append("date < %(end_dt)s") if user: - query_conditions.append(f"from_user = {user[0]}") + sql_dict['user'] = user[0] + query_conditions.append("from_user = %(user)s") query_where = "" if query_conditions: @@ -171,7 +177,7 @@ class StatsRunner(object): """ with self.engine.connect() as con: - df = pd.read_sql_query(query, con) + df = pd.read_sql_query(query, con, params=sql_dict) df['day'] = pd.to_datetime(df.day) df['day'] = df.day.dt.tz_convert(self.tz) @@ -221,17 +227,19 @@ class StatsRunner(object): :param plot: Type of plot. ('box' or 'violin') """ query_conditions = [] + sql_dict = {} if start: - start_dt = pd.to_datetime(start) - query_conditions.append(f"date >= '{start_dt}'") + sql_dict['start_dt'] = pd.to_datetime(start) + query_conditions.append("date >= %(start_dt)s") if end: - end_dt = pd.to_datetime(end) - query_conditions.append(f"date < '{end_dt}'") + sql_dict['end_dt'] = pd.to_datetime(end) + query_conditions.append("date < %(end_dt)s") if user: - query_conditions.append(f"from_user = {user[0]}") + sql_dict['user'] = user[0] + query_conditions.append("from_user = %(user)s") query_where = "" if query_conditions: @@ -247,7 +255,7 @@ class StatsRunner(object): """ with self.engine.connect() as con: - df = pd.read_sql_query(query, con) + df = pd.read_sql_query(query, con, params=sql_dict) df['day'] = pd.to_datetime(df.day) df['day'] = df.day.dt.tz_convert(self.tz) @@ -290,17 +298,19 @@ class StatsRunner(object): :param end: End timestamp (e.g. 2019, 2019-01, 2019-01-01, "2019-01-01 14:21") """ query_conditions = [] + sql_dict = {} if start: - start_dt = pd.to_datetime(start) - query_conditions.append(f"date >= '{start_dt}'") + sql_dict['start_dt'] = pd.to_datetime(start) + query_conditions.append("date >= %(start_dt)s") if end: - end_dt = pd.to_datetime(end) - query_conditions.append(f"date < '{end_dt}'") + sql_dict['end_dt'] = pd.to_datetime(end) + query_conditions.append("date < %(end_dt)s") if user: - query_conditions.append(f"from_user = {user[0]}") + sql_dict['user'] = user[0] + query_conditions.append("from_user = %(user)s") query_where = "" if query_conditions: @@ -315,7 +325,7 @@ class StatsRunner(object): ORDER BY msg_time """ with self.engine.connect() as con: - df = pd.read_sql_query(query, con) + df = pd.read_sql_query(query, con, params=sql_dict) df['msg_time'] = pd.to_datetime(df.msg_time) df['msg_time'] = df.msg_time.dt.tz_convert('America/Toronto') @@ -333,7 +343,7 @@ class StatsRunner(object): ax = fig.subplots() sns.heatmap(df_grouped.T, yticklabels=['M', 'T', 'W', 'Th', 'F', 'Sa', 'Su'], linewidths=.5, - square=True, fmt='d', + square=True, fmt='d', vmin=0, cbar_kws={"orientation": "horizontal"}, cmap="viridis", ax=ax) ax.tick_params(axis='y', rotation=0) ax.set_ylabel("") @@ -359,19 +369,21 @@ class StatsRunner(object): :param end: End timestamp (e.g. 2019, 2019-01, 2019-01-01, "2019-01-01 14:21") """ query_conditions = [] + sql_dict = {} if averages is None: averages = 30 if start: - start_dt = pd.to_datetime(start) - query_conditions.append(f"date >= '{start_dt}'") + sql_dict['start_dt'] = pd.to_datetime(start) + query_conditions.append("date >= %(start_dt)s") if end: - end_dt = pd.to_datetime(end) - query_conditions.append(f"date < '{end_dt}'") + sql_dict['end_dt'] = pd.to_datetime(end) + query_conditions.append("date < %(end_dt)s") if user: - query_conditions.append(f"from_user = {user[0]}") + sql_dict['user'] = user[0] + query_conditions.append("from_user = %(user)s") query_where = "" if query_conditions: @@ -387,7 +399,7 @@ class StatsRunner(object): """ with self.engine.connect() as con: - df = pd.read_sql_query(query, con) + df = pd.read_sql_query(query, con, params=sql_dict) df['day'] = pd.to_datetime(df.day) df['day'] = df.day.dt.tz_convert(self.tz) if averages: