第219行: |
第219行: |
| | | |
| =代码= | | =代码= |
− | python: | + | python:<syntaxhighlight lang="python3"> |
− | | + | def tpm_ei(tpm, log_base = 2): |
− | def tpm_ei(tpm, log_base = 2):
| + | # marginal distribution of y given x ~ Unifrom Dist |
− | # marginal distribution of y given x ~ Unifrom Dist
| + | puy = tpm.sum(axis=0) |
− | puy = tpm.sum(axis=0)
| + | n = tpm.shape[0] |
− | n = tpm.shape[0]
| + | # replace 0 to a small positive number to avoid log error |
− | # replace 0 to a small positive number to avoid log error
| + | eps = 1E-10 |
− | eps = 1E-10
| + | tpm_e = np.where(tpm==0, eps, tpm) |
− | tpm_e = np.where(tpm==0, eps, tpm)
| + | puy_e = np.where(tpm==0, eps, puy) |
− | puy_e = np.where(tpm==0, eps, puy)
| + | |
− |
| + | # calculate EI of specific x |
− | # calculate EI of specific x
| + | ei_x = (np.log2(n * tpm_e / puy_e) / np.log2(log_base) * tpm).sum(axis=1) |
− | ei_x = (np.log2(n * tpm_e / puy_e) / np.log2(log_base) * tpm).sum(axis=1)
| + | |
− |
| + | # calculate total EI |
− | # calculate total EI
| + | ei_all = ei_x.mean() |
− | ei_all = ei_x.mean()
| + | return ei_all |
− | return ei_all
| + | </syntaxhighlight> |
| | | |
| =参考文献= | | =参考文献= |