更改

跳到导航 跳到搜索
添加473字节 、 2024年5月13日 (星期一)
第93行: 第93行:     
=代码=
 
=代码=
见其他任务
+
python:
 +
 
 +
def tpm_ei(tpm, log_base = 2):
 +
    # marginal distribution of y given x ~ Unifrom Dist
 +
    puy = tpm.sum(axis=0)
 +
    n = tpm.shape[0]
 +
    # replace 0 to a small positive number to avoid log error
 +
    eps = 1E-10
 +
    tpm_e = np.where(tpm==0, eps, tpm)
 +
    puy_e = np.where(tpm==0, eps, puy)
 +
   
 +
    # calculate EI of specific x
 +
    ei_x = (np.log2(n * tpm_e / puy_e) / np.log2(log_base)  * tpm).sum(axis=1)
 +
   
 +
    # calculate total EI
 +
    ei_all = ei_x.mean()
 +
    return ei_all
    
=参考文献=
 
=参考文献=
207

个编辑

导航菜单