# 使用HMM建模股票价格波动

HMM的原理参考这里

## 准备数据

```python
import datetime
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.finance import quotes_historical_yahoo
from matplotlib.dates import YearLocator, MonthLocator, DateFormatter
from sklearn.hmm import GaussianHMM

# Prepare data: daily INTC quotes between the two dates below.
date1 = datetime.date(1995, 1, 1)  # start date
date2 = datetime.date(2012, 1, 6)  # end date
# Fetch the quotes from Yahoo Finance.
quotes = quotes_historical_yahoo("INTC", date1, date2)

# Each quote is indexed as: Date, Open, Close, High, Low, Volume.
# Differencing the close price yields one sample fewer than the close
# series itself, so the first sample of every other series is dropped
# to keep all arrays aligned with `diff`.
dates = np.array([row[0] for row in quotes], dtype=int)[1:]
volume = np.array([row[5] for row in quotes])[1:]
close_v = np.array([row[2] for row in quotes])
diff = np.diff(close_v)  # day-over-day change of the close value
close_v = close_v[1:]

# Observation matrix for training: one row per day, columns (diff, volume).
X = np.column_stack([diff, volume])
# Optional sanity check: plt.plot(close_v, volume, "bo")
```

## 拟合模型

```python
# fit Gaussian HMM
# Number of components (hidden states) in the model.
n_components = 5
# Build the Gaussian HMM with diagonal covariances.
# n_iter=1000 presumably caps the number of training iterations -- confirm
# against the library documentation.
model = GaussianHMM(
    n_components=n_components,
    covariance_type="diag",
    n_iter=1000,
)
# fit() receives a list that holds the single observation sequence X.
model.fit([X])
# Predict the hidden-state sequence for the same observations.
hidden_states = model.predict(X)
```

## 打印模型训练的结果并绘图

```python
# print trained parameters and plot

# Print the trained transition matrix, rounded to two decimals.
print(np.round(model.transmat_, 2))
print("means and vars of each hidden state")
for i in range(n_components):
    # The loop body must be indented; without indentation the script
    # does not even parse.
    print("%dth hidden state" % i)
    print("mean = ", model.means_[i])
    print("var = ", np.diag(model.covars_[i]))
    print()

years = YearLocator()   # every year
months = MonthLocator()  # every month
yearsFmt = DateFormatter('%Y')
fig = plt.figure()
# Create the axes explicitly: every call below goes through `ax`, which
# would otherwise be an undefined name.
ax = fig.add_subplot(111)
for i in range(n_components):
    # Use a boolean mask to select the samples assigned to state i and
    # plot them as their own labelled series.
    idx = (hidden_states == i)
    ax.plot_date(dates[idx], close_v[idx], 'o', label="%dth hidden state" % i)

# Format the figure: major ticks per year, minor ticks per month.
ax.legend()
ax.xaxis.set_major_locator(years)
ax.xaxis.set_major_formatter(yearsFmt)
ax.xaxis.set_minor_locator(months)
ax.autoscale_view()
# Formatters assigned to the axes' fmt_xdata / fmt_ydata hooks.
ax.fmt_xdata = DateFormatter('%Y-%m-%d')
# Plain '$' here: the former '\$' was an invalid escape sequence and put a
# literal backslash into the formatted price.
ax.fmt_ydata = lambda x: '$%1.2f' % x
ax.grid(True)
fig.autofmt_xdate()

# NOTE(review): when run as a plain script (not in a notebook) a call to
# plt.show() is presumably needed to display the figure -- confirm.
```