Changes

Added 35 bytes, 26 May 2024 (Sunday)
=Code=
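The function `tpm_ei` computes the effective information (EI) of a transition probability matrix (TPM) when the current state x is intervened on uniformly. Writing p<sub>ij</sub> for `tpm[i, j]` (the probability of moving from state i to state j), n for the number of rows, and b for `log_base` (these symbols are notation introduced here for illustration, not taken from the code), the quantity it evaluates is:

<math>\mathrm{EI} = \frac{1}{n}\sum_{i=1}^{n}\sum_{j} p_{ij}\,\log_b\frac{p_{ij}}{\bar{p}_j}, \qquad \bar{p}_j = \frac{1}{n}\sum_{i=1}^{n} p_{ij}</math>

Terms with p<sub>ij</sub> = 0 contribute nothing. The code divides by the column sum `puy`, which equals n·p̄<sub>j</sub>, and that is why the factor `n` appears inside the logarithm.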
python:<syntaxhighlight lang="python3">
import numpy as np


def tpm_ei(tpm, log_base=2):
    """Effective information (EI) of a transition probability matrix.

    tpm[i, j] = p(y=j | x=i); each row is expected to sum to 1.
    """
    n = tpm.shape[0]
    # column sums of the TPM: n times the marginal p(y) when x ~ Uniform
    puy = tpm.sum(axis=0)
    # replace zeros with a small positive number to avoid log(0)
    eps = 1e-10
    tpm_e = np.where(tpm == 0, eps, tpm)
    puy_e = np.where(tpm == 0, eps, puy)

    # EI contribution of each specific x, converted to base `log_base`
    ei_x = (np.log2(n * tpm_e / puy_e) / np.log2(log_base) * tpm).sum(axis=1)

    # total EI: average over the uniformly distributed x
    ei_all = ei_x.mean()
    return ei_all
</syntaxhighlight>
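Under a uniform intervention, EI is the mutual information between x and y, so it can also be computed as H(Y) - H(Y|X): the entropy of the column-averaged distribution minus the average row entropy. The sketch below uses that identity as an independent cross-check of `tpm_ei`; it is an alternative formulation, not the method used above, and the helper names `entropy` and `tpm_ei_entropy` are hypothetical. It assumes every row of the TPM sums to 1.

<syntaxhighlight lang="python3">
import numpy as np


def entropy(p, log_base=2):
    """Shannon entropy along the last axis, using the convention 0 * log 0 = 0."""
    p = np.asarray(p, dtype=float)
    # log(1) = 0, so zero-probability entries contribute nothing
    safe = np.where(p > 0, p, 1.0)
    return -(p * np.log(safe)).sum(axis=-1) / np.log(log_base)


def tpm_ei_entropy(tpm, log_base=2):
    """Hypothetical helper: EI as H(Y) - H(Y|X) with x uniform over the rows."""
    tpm = np.asarray(tpm, dtype=float)
    py = tpm.mean(axis=0)                         # p(y) under a uniform intervention on x
    h_y = entropy(py, log_base)                   # H(Y)
    h_y_given_x = entropy(tpm, log_base).mean()   # H(Y|X): mean of the row entropies
    return h_y - h_y_given_x


if __name__ == "__main__":
    identity = np.eye(4)             # deterministic, no two states share a successor
    noisy = np.full((4, 4), 0.25)    # every state jumps to a uniformly random state
    print(tpm_ei_entropy(identity))  # approximately 2.0 (log2 of 4)
    print(tpm_ei_entropy(noisy))     # approximately 0.0
</syntaxhighlight>

For any valid TPM the two formulations should agree up to floating-point error, so comparing `tpm_ei(tpm)` with `tpm_ei_entropy(tpm)` on a few matrices is a cheap sanity check.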
    
=References=