theonegareth commited on
Commit
9a59bc2
·
verified ·
1 Parent(s): 5b17243

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +157 -0
inference.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for Gold Price Direction Predictor
3
+
4
+ This script demonstrates how to load the model and make predictions.
5
+ """
6
+
7
+ import pandas as pd
8
+ import numpy as np
9
+ from joblib import load
10
+ from huggingface_hub import hf_hub_download
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
+
14
+
15
+ def load_model():
16
+ """Load the trained model from Hugging Face"""
17
+ try:
18
+ model_path = hf_hub_download("theonegareth/GoldPricePredictor", "gold_direction_model.joblib")
19
+ model = load(model_path)
20
+ print("Model loaded successfully!")
21
+ return model
22
+ except Exception as e:
23
+ print(f"Error loading model: {e}")
24
+ return None
25
+
26
+
27
+ def add_features_adaptive(data: pd.DataFrame, price='close') -> pd.DataFrame:
28
+ """
29
+ Feature engineering function (same as used in training)
30
+ """
31
+ out = data.copy()
32
+ n = len(out)
33
+
34
+ if n < 8:
35
+ raise ValueError(f"Dataset too small (n={n}). Need at least 8 rows.")
36
+
37
+ out['ret'] = out[price].pct_change()
38
+ out['log_ret'] = np.log1p(out['ret'])
39
+
40
+ # Adaptive lags and windows
41
+ max_lag = max(1, min(5, n // 6))
42
+ lag_list = list(range(1, max_lag + 1))
43
+ win_candidates = [3, 5, 10, 20]
44
+ win_list = [w for w in win_candidates if w < n-2]
45
+ if not win_list:
46
+ win_list = [3]
47
+
48
+ for L in lag_list:
49
+ out[f'ret_lag_{L}'] = out['ret'].shift(L)
50
+
51
+ for w in win_list:
52
+ out[f'roll_mean_{w}'] = out['ret'].rolling(w, min_periods=1).mean()
53
+ out[f'roll_std_{w}'] = out['ret'].rolling(w, min_periods=1).std()
54
+ out[f'roll_min_{w}'] = out['ret'].rolling(w, min_periods=1).min()
55
+ out[f'roll_max_{w}'] = out['ret'].rolling(w, min_periods=1).max()
56
+
57
+ # RSI
58
+ rsi_w = max(3, min(14, n // 6))
59
+ delta = out[price].diff()
60
+ gain = (delta.where(delta > 0, 0.0)).rolling(rsi_w, min_periods=1).mean()
61
+ loss = (-delta.where(delta < 0, 0.0)).rolling(rsi_w, min_periods=1).mean()
62
+ rs = gain / (loss + 1e-9)
63
+ out['rsi14'] = 100 - (100 / (1 + rs))
64
+
65
+ # MACD
66
+ fast = max(6, min(12, n // 5))
67
+ slow = max(fast+4, min(26, n // 3))
68
+ signal = max(5, min(9, n // 6))
69
+ ema_fast = out[price].ewm(span=fast, adjust=False).mean()
70
+ ema_slow = out[price].ewm(span=slow, adjust=False).mean()
71
+ out['macd'] = ema_fast - ema_slow
72
+ out['macd_signal'] = out['macd'].ewm(span=signal, adjust=False).mean()
73
+ out['macd_hist'] = out['macd'] - out['macd_signal']
74
+
75
+ # Bollinger
76
+ bb_w = max(5, min(20, n // 4))
77
+ ma = out[price].rolling(bb_w, min_periods=1).mean()
78
+ sd = out[price].rolling(bb_w, min_periods=1).std()
79
+ out['bb_mid'] = ma
80
+ out['bb_up'] = ma + 2*sd
81
+ out['bb_low'] = ma - 2*sd
82
+ out['bb_width'] = (out['bb_up'] - out['bb_low']) / (out['bb_mid'] + 1e-9)
83
+
84
+ # Calendar
85
+ out['dow'] = out['date'].dt.weekday
86
+ out['month'] = out['date'].dt.month
87
+
88
+ return out
89
+
90
+
91
+ def predict_next_day_direction(model, historical_data: pd.DataFrame, threshold=0.52):
92
+ """
93
+ Predict next-day direction from historical price data
94
+
95
+ Parameters:
96
+ - model: Loaded sklearn model
97
+ - historical_data: DataFrame with 'date' and 'close' columns
98
+ - threshold: Probability threshold for prediction (optimized from training)
99
+
100
+ Returns:
101
+ - prediction: 1 for up, 0 for down
102
+ - probability: Probability of going up
103
+ """
104
+ # Ensure data is sorted
105
+ historical_data = historical_data.sort_values('date').reset_index(drop=True)
106
+
107
+ # Add features
108
+ feat = add_features_adaptive(historical_data, price='close')
109
+
110
+ # Drop rows with NaN (lags, etc.)
111
+ feat = feat.dropna(subset=[c for c in feat.columns if c.startswith('ret_lag_')])
112
+
113
+ if len(feat) == 0:
114
+ raise ValueError("Not enough data to compute features")
115
+
116
+ # Get latest features
117
+ latest_features = feat.iloc[[-1]]
118
+
119
+ # Select feature columns (exclude non-feature columns)
120
+ feature_cols = [c for c in latest_features.columns
121
+ if c not in ['date','close','ret','log_ret','next_close','target']
122
+ and not c.startswith('roll_') or c in ['roll_mean_3','roll_std_3','roll_min_3','roll_max_3',
123
+ 'roll_mean_5','roll_std_5','roll_min_5','roll_max_5']]
124
+
125
+ # Ensure we have the right columns (this might need adjustment based on training)
126
+ X = latest_features[feature_cols]
127
+
128
+ # Predict
129
+ proba_up = model.predict_proba(X)[:, 1][0]
130
+ prediction = int(proba_up >= threshold)
131
+
132
+ direction = "UP 📈" if prediction == 1 else "DOWN 📉"
133
+
134
+ return prediction, proba_up, direction
135
+
136
+
137
+ if __name__ == "__main__":
138
+ # Example usage
139
+ model = load_model()
140
+
141
+ if model is None:
142
+ print("Failed to load model")
143
+ exit(1)
144
+
145
+ # Example historical data (replace with your data)
146
+ example_data = pd.DataFrame({
147
+ 'date': pd.date_range('2023-01-01', periods=50, freq='D'),
148
+ 'close': np.random.uniform(1000000, 1200000, 50) # Random prices
149
+ })
150
+
151
+ try:
152
+ pred, proba, direction = predict_next_day_direction(model, example_data)
153
+ print(f"Next-day prediction: {direction}")
154
+ print(".3f")
155
+ print(f"Decision threshold: 0.52")
156
+ except Exception as e:
157
+ print(f"Error making prediction: {e}")