Skip to content

Commit ef98a3d

Browse files
committed
added script.py
1 parent 2314225 commit ef98a3d

File tree

1 file changed

+158
-0
lines changed
  • programming/projects/oldest-businesses

1 file changed

+158
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
2+
import matplotlib.pyplot as plt
3+
import pandas as pd
4+
import seaborn as sns
5+
6+
7+
class BusinessDataAnalyzer():
    '''
    Loads business, country and category CSV files and prints/plots summaries.

    The constructor takes the CSV locations as keyword arguments; every name
    listed in ``_CLASS_VARS`` is required. ``analyze`` runs the whole report.
    '''

    # Keyword arguments that must all be supplied to the constructor.
    _CLASS_VARS = ['business_csv_path', 'newbusiness_csv_path', 'countries_csv_path', 'categories_csv_path']

    def __init__(self, **kwargs):
        '''
        Stores the CSV paths passed as keyword arguments.
        Keyword arguments not named in ``_CLASS_VARS`` are ignored.
        :param kwargs: one path per name in ``_CLASS_VARS``
        :raises AssertionError: if one or more of the required names is missing
        '''
        self._mapping = {key: value for key, value in kwargs.items() if key in self._CLASS_VARS}
        # Explicit check instead of `assert`, which is stripped when Python
        # runs with -O; the exception type is kept for existing callers.
        missing = [key for key in self._CLASS_VARS if key not in self._mapping]
        if missing:
            raise AssertionError(f'One or more input variables are missing! {", ".join(missing)}')

    def _load_csv(self, path: str) -> pd.DataFrame:
        '''
        Reads input csv from the path specified.
        :param path: [str] csv path
        :return: pandas df
        '''
        return pd.read_csv(path)

    def _group_by(self, dataframe: pd.DataFrame,
                  columns: list,
                  aggfunc: dict = None) -> 'pd.DataFrame | pd.core.groupby.DataFrameGroupBy':
        '''
        Groups a pandas df by one or more columns, optionally aggregating.
        Errors raised by pandas (e.g. KeyError for an unknown column) are
        propagated instead of being printed and turned into a None result,
        so a failure is reported where it happens.
        :param dataframe: pandas df
        :param columns: [list(str)] column(s) to group by
        :param aggfunc: [dict] column -> calculation metric(mean, median etc) [optional]
        :return: aggregated pandas df if aggfunc is given, otherwise the groupby object
        '''
        grouped = dataframe.groupby(by=columns)
        return grouped.agg(aggfunc) if aggfunc else grouped

    def _merge(self, dataframe: pd.DataFrame,
               dataframe_to_merge: pd.DataFrame,
               on: list,
               how: str = None,
               indicator: bool = False) -> pd.DataFrame:
        '''
        Merges 2 pandas dfs.
        :param dataframe: first dataframe
        :param dataframe_to_merge: second dataframe
        :param on: [list(str)] column name(s) to merge on
        :param how: [str] merge type('outer', 'left', 'right' etc) [optional];
                    when omitted, pandas' default 'inner' merge is used
        :param indicator: [bool] adds info on source of each row [optional]
        :return: merged pandas df
        '''
        # A single call covers every how/indicator combination.
        return dataframe.merge(dataframe_to_merge, on=on, how=how or 'inner', indicator=indicator)

    def _sort_data(self, dataframe: pd.DataFrame,
                   sort_by: list = None,
                   ascending: bool = False) -> pd.DataFrame:
        '''
        Sort pandas df by one or more columns.
        :param dataframe: pandas df
        :param sort_by: [list(str)] list of column or columns to sort by
        :param ascending: [bool] sort in ascending order; descending by default [optional]
        :return: sorted pandas df
        '''
        return dataframe.sort_values(by=sort_by, ascending=ascending)

    def _plot_data(self, dataframe: pd.DataFrame,
                   x: str = None,
                   y: str = None,
                   kind: str = 'count',
                   col: str = None,
                   col_wrap: int = None,
                   hue=None) -> 'sns.FacetGrid':
        '''
        Creates a seaborn categorical plot based on user input.
        A count plot is generated unless another ``kind`` is requested.
        :param dataframe: pandas df
        :param x: [str] x-axis variable [optional]
        :param y: [str] y-axis variable [optional]
        :param kind: [str] plot type [optional]
        :param col: [str] split visualisation into multiple plots based on column value [optional]
        :param col_wrap: [int] forwarded to seaborn's col_wrap [optional]
        :param hue: forwarded to seaborn's hue [optional]
        :return: the object returned by sns.catplot
        '''
        # Forward the arguments directly; building a source string and running
        # it through eval() broke on values containing quotes and executed
        # arbitrary text.
        return sns.catplot(data=dataframe, x=x, y=y, kind=kind,
                           col=col, col_wrap=col_wrap, hue=hue)

    def analyze(self):
        '''
        Runs the full report on the configured CSV files:
        prints the number of businesses per category, prints the "CAT4"
        businesses founded before 1800, shows a count plot of categories split
        by continent and prints the earliest founding year per
        continent/category.
        :return: None
        '''
        businesses = self._load_csv(self._mapping['business_csv_path'])
        countries = self._load_csv(self._mapping['countries_csv_path'])
        # NOTE(review): loaded but not used by any step below; the read is kept
        # so a missing file still fails here - confirm whether it can go.
        new_businesses = self._load_csv(self._mapping['newbusiness_csv_path'])
        categories = self._load_csv(self._mapping['categories_csv_path'])

        # Outer merge keeps categories without businesses; 'count' skips the
        # resulting missing year_founded values.
        businesses_categories = self._merge(businesses, categories, on=['category_code'], how='outer')
        count_business_cats = self._group_by(businesses_categories,
                                             columns=['category'],
                                             aggfunc={'year_founded': 'count'})
        count_business_cats.columns = ['count_business_cats']
        print(count_business_cats)

        # Presumably CAT4 is the restaurant category (going by the variable
        # name) - verify against the categories csv.
        old_restaurants = businesses_categories.query('category_code == "CAT4" & year_founded < 1800')
        old_restaurants = self._sort_data(old_restaurants, sort_by=['year_founded'])
        print(old_restaurants)

        businesses_categories_temp = self._merge(businesses, categories, on=['category_code'], how='inner')
        businesses_categories_countries = self._merge(businesses_categories_temp,
                                                      countries,
                                                      on=['country_code'],
                                                      how='inner')
        businesses_categories_countries = self._sort_data(businesses_categories_countries,
                                                          sort_by=['year_founded'],
                                                          ascending=True)

        plot_businesses = self._plot_data(dataframe=businesses_categories_countries,
                                          y='category',
                                          col='continent',
                                          col_wrap=2)
        plot_businesses.fig.suptitle('No. of businesses per category')
        plt.show()

        oldest_by_continent_category = self._group_by(businesses_categories_countries,
                                                      columns=['continent', 'category'],
                                                      aggfunc={'year_founded': 'min'})
        print(oldest_by_continent_category.head())
150+
151+
152+
# Locations of the four input CSV files, keyed by the keyword names that
# BusinessDataAnalyzer requires.
csvs = dict(
    business_csv_path='data/businesses.csv',
    newbusiness_csv_path='data/new_businesses.csv',
    countries_csv_path='data/countries.csv',
    categories_csv_path='data/categories.csv',
)

# Build the analyzer from those paths and run the report.
obj = BusinessDataAnalyzer(**csvs)
obj.analyze()

0 commit comments

Comments
 (0)