|
| 1 | + |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import pandas as pd |
| 4 | +import seaborn as sns |
| 5 | + |
| 6 | + |
| 7 | +class BusinessDataAnalyzer(): |
| 8 | + |
| 9 | + _CLASS_VARS = ['business_csv_path', 'newbusiness_csv_path', 'countries_csv_path', 'categories_csv_path'] |
| 10 | + |
| 11 | + def __init__(self, **kwargs): |
| 12 | + self._mapping = {key : value for key, value in kwargs.items() if key in self._CLASS_VARS} |
| 13 | + try: |
| 14 | + assert len(self._mapping) == len(self._CLASS_VARS) |
| 15 | + except AssertionError as e: |
| 16 | + raise(AssertionError(f'One or more input variables are missing! {e}')) |
| 17 | + |
| 18 | + def _load_csv(self, path : str) -> pd.read_csv: |
| 19 | + '''Reads input csv from the path specified |
| 20 | + Returns pandas df. |
| 21 | + :param path: [str] csv path |
| 22 | + :return: pd.DataFrame |
| 23 | + ''' |
| 24 | + return pd.read_csv(path) |
| 25 | + |
| 26 | + def _group_by(self, dataframe : pd.DataFrame, |
| 27 | + columns : list, |
| 28 | + aggfunc : dict = None) -> pd.DataFrame.groupby: |
| 29 | + ''' |
| 30 | + Takes pandas df, column name(s) & aggregate function(optional) as inputs |
| 31 | + Returns pandas.DataFrame.groupby object. |
| 32 | + :param dataframe: pandas df |
| 33 | + :param columns: [str] column(s) criteria |
| 34 | + :param aggfunc: [dict] calculation metric(mean, median etc) [optional] |
| 35 | + :return: pandas.DataFrame.groupby object |
| 36 | + ''' |
| 37 | + try: |
| 38 | + if aggfunc: |
| 39 | + return dataframe.groupby(by=columns).agg(aggfunc) |
| 40 | + else: |
| 41 | + return dataframe.groupby(by=columns) |
| 42 | + except Exception as e: |
| 43 | + print(e) |
| 44 | + return None |
| 45 | + |
| 46 | + def _merge(self, dataframe : pd.DataFrame, |
| 47 | + dataframe_to_merge : pd.DataFrame, |
| 48 | + on : list, |
| 49 | + how : str = None, |
| 50 | + indicator : bool = False) -> pd.DataFrame: |
| 51 | + ''' |
| 52 | + Merges 2 pandas dfs |
| 53 | + Returns pandas df. |
| 54 | + :param dataframe: first dataframe |
| 55 | + :param dataframe_to_merge: second dataframe |
| 56 | + :param on: [str] column name to merge on |
| 57 | + :param how: [str] merge type('outer', 'left', 'right' etc) [optional] |
| 58 | + :param indicator: [bool] adds info on source of each row [optional] |
| 59 | + :return: |
| 60 | + ''' |
| 61 | + if how or indicator: |
| 62 | + if how: |
| 63 | + if indicator: return dataframe.merge(dataframe_to_merge, on=on, how=how, indicator=True) |
| 64 | + else: return dataframe.merge(dataframe_to_merge, on=on, how=how) |
| 65 | + else: |
| 66 | + return dataframe.merge(dataframe_to_merge, on=on, indicator=indicator) |
| 67 | + else: |
| 68 | + return dataframe.merge(dataframe_to_merge, on=on) |
| 69 | + |
| 70 | + def _sort_data(self, dataframe : pd.DataFrame, |
| 71 | + sort_by : list = None, |
| 72 | + ascending : bool = False) -> pd.DataFrame: |
| 73 | + ''' |
| 74 | + Sort pandas df by one or more columns. |
| 75 | + :param dataframe: pandas df |
| 76 | + :param sort_by: [list(str)] list of column or columns to sort by |
| 77 | + :param ascending: [bool] sorted data in ascending order [optional] |
| 78 | + :return: pandas df |
| 79 | + ''' |
| 80 | + return dataframe.sort_values(by=sort_by, ascending=ascending) |
| 81 | + |
| 82 | + def _plot_data(self, dataframe : pd.DataFrame, |
| 83 | + x : str = None, |
| 84 | + y : str = None, |
| 85 | + kind : str ='count', |
| 86 | + col : str = None, |
| 87 | + col_wrap : int = None, |
| 88 | + hue=None) -> sns.catplot: |
| 89 | + ''' |
| 90 | + Creates plot based on user input. |
| 91 | + If no y label is provided, countplot is generated. |
| 92 | + :param dataframe: pandas df |
| 93 | + :param x: [str] x-axis variable [optional] |
| 94 | + :param y: [str] y-axis variable [optional] |
| 95 | + :param kind: [str] plot type [optional] |
| 96 | + :param col: [str] split visualisation into multiple plots based on column value [optional] |
| 97 | + :param col_wrap: number of plots per column [optional] |
| 98 | + :return: sns.catplot |
| 99 | + ''' |
| 100 | + vars = locals() |
| 101 | + del vars['self'] |
| 102 | + del vars['dataframe'] |
| 103 | + |
| 104 | + func_args = ', '.join([f"{key}='{value}'" if value and isinstance(value, str) else f"{key}={value}" for key, value in vars.items()]) |
| 105 | + return eval(f'sns.catplot(data=dataframe, {func_args})') |
| 106 | + |
| 107 | + |
| 108 | + def analyze(self): |
| 109 | + ''' |
| 110 | + :return: |
| 111 | + ''' |
| 112 | + businesses = self._load_csv(self._mapping['business_csv_path']) |
| 113 | + countries = self._load_csv(self._mapping['countries_csv_path']) |
| 114 | + new_businesses = self._load_csv(self._mapping['newbusiness_csv_path']) |
| 115 | + categories = self._load_csv(self._mapping['categories_csv_path']) |
| 116 | + |
| 117 | + businesses_categories = self._merge(businesses, categories, on=['category_code'], how='outer') |
| 118 | + count_business_cats = self._group_by(businesses_categories, columns=['category'], aggfunc={'year_founded' : 'count'}) |
| 119 | + count_business_cats.columns = ['count_business_cats'] |
| 120 | + print(count_business_cats) |
| 121 | + |
| 122 | + businesses_categories['category_code'].unique() |
| 123 | + old_restaurants = businesses_categories.query('category_code == "CAT4" & year_founded < 1800') |
| 124 | + |
| 125 | + old_restaurants = self._sort_data(old_restaurants, sort_by=['year_founded']) |
| 126 | + print(old_restaurants) |
| 127 | + |
| 128 | + businesses_categories_temp = self._merge(businesses, categories, on=['category_code'], how='inner') |
| 129 | + businesses_categories_countries = self._merge(businesses_categories_temp, |
| 130 | + countries, |
| 131 | + on=['country_code'], |
| 132 | + how='inner') |
| 133 | + |
| 134 | + businesses_categories_countries = self._sort_data(businesses_categories_countries, |
| 135 | + sort_by=['year_founded'], |
| 136 | + ascending=True) |
| 137 | + |
| 138 | + plot_businesses = self._plot_data(dataframe=businesses_categories_countries, |
| 139 | + y='category', |
| 140 | + col='continent', |
| 141 | + col_wrap=2) |
| 142 | + plot_businesses.fig.suptitle('No. of businesses per category') |
| 143 | + plt.show() |
| 144 | + |
| 145 | + oldest_by_continent_category = self._group_by(businesses_categories_countries, |
| 146 | + columns=['continent', 'category']).agg({'year_founded' : 'min'}) |
| 147 | + |
| 148 | + |
| 149 | + print(oldest_by_continent_category.head()) |
| 150 | + |
| 151 | + |
| 152 | +csvs = {'business_csv_path' : 'data/businesses.csv', |
| 153 | + 'newbusiness_csv_path' : 'data/new_businesses.csv', |
| 154 | + 'countries_csv_path' : 'data/countries.csv', |
| 155 | + 'categories_csv_path' : 'data/categories.csv'} |
| 156 | + |
| 157 | +obj = BusinessDataAnalyzer(**csvs) |
| 158 | +obj.analyze() |
0 commit comments