Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Works on GPU #75

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4,416 changes: 4,416 additions & 0 deletions OneProductData_GET.ipynb

Large diffs are not rendered by default.

510 changes: 510 additions & 0 deletions Plot.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion data_provider/data_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4
from data_provider.data_loader import Dataset_ean, Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4
from torch.utils.data import DataLoader

data_dict = {
@@ -10,6 +10,7 @@
'Traffic': Dataset_Custom,
'Weather': Dataset_Custom,
'm4': Dataset_M4,
'ean': Dataset_ean,
}


@@ -42,6 +43,7 @@ def data_provider(args, flag):
freq=freq,
seasonal_patterns=args.seasonal_patterns
)

else:
data_set = Data(
root_path=args.root_path,
74 changes: 74 additions & 0 deletions data_provider/data_loader.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,80 @@
warnings.filterwarnings('ignore')


class Dataset_ean(Dataset):
    """Weekly sold-units dataset for a single EAN (product), read from a CSV.

    The CSV is expected to contain a leading unnamed index column (dropped on
    load), an ``end_date`` column as the first remaining column, and the
    target column (default ``sold_units``).  The split is chronological:
    70% train, 20% validation, and the remainder test.

    Each sample is an Informer/Autoformer-style window:
    ``seq_x`` covers ``seq_len`` input steps and ``seq_y`` covers the last
    ``label_len`` steps of the input (decoder warm-up) plus ``pred_len``
    future steps.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='data.csv',
                 target='sold_units', scale=False, timeenc=0, freq='W', percent=100,
                 seasonal_patterns=None):
        # size is (seq_len, label_len, pred_len).
        if size is None:
            self.seq_len = 13   # 13 weeks (~0.25 year) of input history
            self.label_len = 4  # decoder warm-up: last month of the input
            self.pred_len = 4   # forecast horizon: 4 weeks (~1 month)
        else:
            self.seq_len, self.label_len, self.pred_len = size

        assert flag in ['train', 'test', 'val']
        self.set_type = {'train': 0, 'val': 1, 'test': 2}[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()
        # Number of input channels; always 1 here because only the single
        # target column is loaded in __read_data__.
        self.enc_in = self.data_x.shape[-1]

    def __read_data__(self):
        """Load the CSV, split chronologically, and build time features."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))
        # TODO: remove this drop once the data export no longer carries a
        # spurious unnamed index column.
        df_raw.drop("Unnamed: 0", axis=1, inplace=True)

        num_weeks = len(df_raw)
        num_train = int(num_weeks * 0.7)
        num_val = int(num_weeks * 0.2)
        num_test = num_weeks - num_train - num_val

        border1s = [0, num_train, num_train + num_val]
        border2s = [num_train, num_train + num_val, num_weeks]

        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.scale:
            # Fit the scaler on the training split only to avoid leakage,
            # then transform the full series.
            train_data = df_raw.iloc[:num_train]
            self.scaler.fit(train_data[[self.target]].values)
            data = self.scaler.transform(df_raw[[self.target]].values)
        else:
            data = df_raw[[self.target]].values

        # 'end_date' is the first column after the index column was dropped.
        df_stamp = pd.to_datetime(df_raw.iloc[:, 0][border1:border2])
        # One row per timestep: (year, month, day, weekday).
        time_features = np.vstack(
            (df_stamp.dt.year, df_stamp.dt.month, df_stamp.dt.day, df_stamp.dt.weekday)
        ).T

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.time_features = time_features

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        # The decoder window overlaps the last `label_len` steps of the
        # input sequence and extends `pred_len` steps into the future —
        # the standard windowing for this dataset family.  (The previous
        # version started seq_y at s_end, so the label portion exposed
        # future targets to the decoder.)
        r_begin = s_end - self.label_len
        r_end = s_end + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.time_features[s_begin:s_end]
        seq_y_mark = self.time_features[r_begin:r_end]
        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # One sample per valid window start.  The previous version
        # multiplied by enc_in (the multi-channel convention) although
        # __getitem__ never maps index -> channel; enc_in is always 1 for
        # this single-target dataset, so the multiplier is dropped.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map scaled values back to the original units of the target."""
        return self.scaler.inverse_transform(data)




class Dataset_ETT_hour(Dataset):
def __init__(self, root_path, flag='train', size=None,
features='S', data_path='ETTh1.csv',
222 changes: 222 additions & 0 deletions dataset/data.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
,end_date,sold_units
19,2020-01-04,89.0
87,2020-01-11,89.0
100,2020-01-18,67.0
2,2020-01-25,82.0
104,2020-02-01,96.0
132,2020-02-08,113.0
95,2020-02-15,94.0
18,2020-02-22,69.0
67,2020-02-29,128.0
138,2020-03-07,192.0
79,2020-03-14,113.0
38,2020-03-21,104.0
96,2020-03-28,111.0
131,2020-04-04,74.0
11,2020-04-11,21.0
119,2020-04-18,4.0
108,2020-04-25,0.0
133,2020-05-02,43.0
53,2020-05-09,321.0
117,2020-05-16,237.0
144,2020-05-23,217.0
48,2020-05-30,234.0
101,2020-06-06,356.0
137,2020-06-13,317.0
85,2020-06-20,415.0
156,2020-06-27,972.0
31,2020-07-04,313.0
115,2020-07-11,317.0
33,2020-07-18,306.0
116,2020-07-25,216.0
97,2020-08-01,215.0
149,2020-08-08,211.0
21,2020-08-15,167.0
112,2020-08-22,148.0
37,2020-08-29,191.0
35,2020-09-05,169.0
110,2020-09-12,244.0
56,2020-09-19,146.0
126,2020-09-26,176.0
29,2020-10-03,381.0
118,2020-10-10,127.0
123,2020-10-17,88.0
98,2020-10-24,89.0
127,2020-10-31,200.0
154,2020-11-07,156.0
151,2020-11-14,254.0
45,2020-11-21,260.0
143,2020-11-28,243.0
125,2020-12-05,254.0
140,2020-12-12,191.0
102,2020-12-19,191.0
105,2020-12-26,87.0
41,2021-01-02,334.0
83,2021-01-09,171.0
77,2021-01-16,182.0
46,2021-01-23,163.0
129,2021-01-30,135.0
94,2021-02-06,177.0
71,2021-02-13,159.0
139,2021-02-20,370.0
9,2021-02-27,277.0
20,2021-03-06,332.0
136,2021-03-13,185.0
15,2021-03-20,157.0
36,2021-03-27,196.0
121,2021-04-03,226.0
14,2021-04-10,159.0
43,2021-04-17,381.0
44,2021-04-24,147.0
153,2021-05-01,117.0
141,2021-05-08,130.0
122,2021-05-15,108.0
23,2021-05-22,86.0
32,2021-05-29,64.0
1,2021-06-05,99.0
60,2021-06-12,149.0
73,2021-06-19,158.0
130,2021-06-26,130.0
106,2021-07-03,175.0
72,2021-07-10,128.0
148,2021-07-17,101.0
59,2021-07-24,92.0
91,2021-07-31,129.0
4,2021-08-07,192.0
47,2021-08-14,190.0
27,2021-08-21,221.0
13,2021-08-28,123.0
24,2021-09-04,132.0
128,2021-09-11,98.0
113,2021-09-18,87.0
150,2021-09-25,97.0
25,2021-10-02,112.0
42,2021-10-09,101.0
30,2021-10-16,112.0
86,2021-10-23,85.0
8,2021-10-30,91.0
12,2021-11-06,83.0
78,2021-11-13,128.0
0,2021-11-20,112.0
50,2021-11-27,164.0
84,2021-12-04,331.0
28,2021-12-11,162.0
7,2021-12-18,113.0
92,2021-12-25,74.0
114,2022-01-01,92.0
40,2022-01-08,107.0
76,2022-01-15,100.0
63,2022-01-22,77.0
34,2022-01-29,94.0
80,2022-02-05,83.0
3,2022-02-12,135.0
88,2022-02-19,106.0
107,2022-02-26,108.0
147,2022-03-05,126.0
145,2022-03-12,139.0
146,2022-03-19,104.0
5,2022-03-26,112.0
90,2022-04-02,115.0
6,2022-04-09,121.0
109,2022-04-16,102.0
152,2022-04-23,180.0
135,2022-04-30,286.0
54,2022-05-07,187.0
61,2022-05-14,171.0
89,2022-05-21,141.0
134,2022-05-28,157.0
66,2022-06-04,113.0
10,2022-06-11,138.0
81,2022-06-18,149.0
75,2022-06-25,154.0
55,2022-07-02,208.0
111,2022-07-09,259.0
93,2022-07-16,224.0
26,2022-07-23,198.0
52,2022-07-30,230.0
58,2022-08-06,102.0
120,2022-08-13,152.0
22,2022-08-20,184.0
70,2022-08-27,189.0
49,2022-09-03,153.0
69,2022-09-10,130.0
103,2022-09-17,135.0
155,2022-09-24,176.0
39,2022-10-01,180.0
68,2022-10-08,246.0
16,2022-10-15,348.0
57,2022-10-22,392.0
99,2022-10-29,435.0
51,2022-11-05,276.0
124,2022-11-12,188.0
17,2022-11-19,169.0
82,2022-11-26,360.0
74,2022-12-03,420.0
65,2022-12-10,219.0
64,2022-12-17,207.0
62,2022-12-24,150.0
142,2022-12-31,131.0
179,2023-01-07,184.0
211,2023-01-14,139.0
160,2023-01-21,126.0
159,2023-01-28,146.0
196,2023-02-04,168.0
212,2023-02-11,140.0
169,2023-02-18,136.0
180,2023-02-25,96.0
168,2023-03-04,124.0
220,2023-03-11,153.0
224,2023-03-18,160.0
206,2023-03-25,145.0
185,2023-04-01,184.0
190,2023-04-08,164.0
188,2023-04-15,143.0
186,2023-04-22,166.0
213,2023-04-29,144.0
230,2023-05-06,131.0
227,2023-05-13,161.0
222,2023-05-20,219.0
187,2023-05-27,239.0
229,2023-06-03,307.0
162,2023-06-10,291.0
171,2023-06-17,251.0
191,2023-06-24,187.0
183,2023-07-01,237.0
199,2023-07-08,204.0
209,2023-07-15,192.0
172,2023-07-22,180.0
167,2023-07-29,152.0
192,2023-08-05,177.0
203,2023-08-12,135.0
181,2023-08-19,170.0
184,2023-08-26,116.0
197,2023-09-02,162.0
202,2023-09-09,193.0
231,2023-09-16,152.0
182,2023-09-23,164.0
174,2023-09-30,196.0
223,2023-10-07,200.0
175,2023-10-14,196.0
201,2023-10-21,143.0
232,2023-10-28,250.0
207,2023-11-04,278.0
177,2023-11-11,508.0
210,2023-11-18,148.0
195,2023-11-25,244.0
173,2023-12-02,488.0
204,2023-12-09,209.0
208,2023-12-16,201.0
163,2023-12-23,164.0
164,2023-12-30,154.0
165,2024-01-06,200.0
215,2024-01-13,181.0
161,2024-01-20,143.0
218,2024-01-27,238.0
226,2024-02-03,190.0
233,2024-02-10,145.0
216,2024-02-17,126.0
194,2024-02-24,198.0
214,2024-03-02,149.0
219,2024-03-09,148.0
157,2024-03-16,146.0
221,2024-03-23,100.0
1 change: 1 addition & 0 deletions dataset/prompt_bank/ean.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
EAN is a dataset that contains only the weekly sold_units of a single EAN product from L'Oreal's cosmetics range.
Loading