← 返回题库
初级

第18章 概率潜在语义分析 - 习题18.3 - 实现PLSA类

未完成
初级参考 完整示例代码供参考，建议自己理解后重新输入。
class PLSA:
    """Probabilistic Latent Semantic Analysis fitted with the EM algorithm.

    The model explains the document-word counts as
    ``P(w | d) = sum_z P(z | d) * P(w | z)`` and alternates an E step
    (posterior ``P(z | d, w)``) with an M step (re-estimating ``P(z | d)``
    and ``P(w | z)`` from expected counts).

    Args:
        K: Number of latent topics.
        max_iter: Number of EM iterations to run.
        random_state: Optional integer seed for the random initialisation.
            ``None`` (the default) draws from NumPy's global random state, so
            a prior ``np.random.seed(...)`` still controls reproducibility.
    """

    def __init__(self, K, max_iter, random_state=None):
        self.K = K
        self.max_iter = max_iter
        self.random_state = random_state

    @staticmethod
    def _safe(denom):
        """Replace zero denominators with 1 so empty rows give zeros, not NaN."""
        return np.where(denom == 0, 1.0, denom)

    @staticmethod
    def _normalize_rows(a):
        """Scale each row of the 2-D array ``a`` to sum to 1 (all-zero rows stay zero)."""
        return a / PLSA._safe(a.sum(axis=1, keepdims=True))

    def fit(self, X):
        """Run EM on the count matrix ``X`` and return the fitted parameters.

        Args:
            X: Array-like (or sparse matrix) of shape ``(n_documents, n_words)``
                holding non-negative word counts or weights per document.

        Returns:
            A tuple ``(p_w_z, p_z_d)``:
            ``p_w_z`` has shape ``(K, n_words)``; row ``k`` is ``P(w | z=k)``.
            ``p_z_d`` has shape ``(n_documents, K)``; row ``d`` is ``P(z | d)``.
            After at least one iteration, documents whose counts are all zero
            get an all-zero row in ``p_z_d``.
            The same arrays are also stored as ``self.p_w_z_`` / ``self.p_z_d_``.

        Raises:
            ValueError: If ``X`` is not two-dimensional.
        """
        if hasattr(X, "toarray"):
            # Densify sparse matrices; the previous element-wise loops accepted them.
            X = X.toarray()
        X = np.asarray(X, dtype=float)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D, got shape {X.shape}")
        n_d, n_w = X.shape

        if self.random_state is None:
            rng = np.random
        else:
            rng = np.random.RandomState(self.random_state)

        # Random start, normalised so every row is a proper distribution.
        # (Row scale of P(w|z) does not cancel in the E step, so this matters.)
        p_z_d = self._normalize_rows(rng.rand(n_d, self.K))  # P(z|d)
        p_w_z = self._normalize_rows(rng.rand(self.K, n_w))  # P(w|z)

        for _ in range(self.max_iter):
            # E step: P(z|d,w) ∝ P(z|d) * P(w|z); shape (n_d, n_w, K).
            joint = p_z_d[:, np.newaxis, :] * p_w_z.T[np.newaxis, :, :]
            p_z_dw = joint / self._safe(joint.sum(axis=2, keepdims=True))

            # Expected topic counts n(d,w) * P(z|d,w); shape (n_d, n_w, K).
            weighted = X[:, :, np.newaxis] * p_z_dw

            # M step: P(z|d) = sum_w n(d,w) P(z|d,w) / sum_w n(d,w).
            p_z_d = weighted.sum(axis=1) / self._safe(X.sum(axis=1))[:, np.newaxis]

            # M step: P(w|z) = sum_d n(d,w) P(z|d,w), normalised over w.
            p_w_z = self._normalize_rows(weighted.sum(axis=0).T)

        self.p_w_z_ = p_w_z
        self.p_z_d_ = p_z_d
        return p_w_z, p_z_d


# https://github.com/lipiji/PG_PLSA/blob/master/plsa.py
Python 代码 🔒 登录后使用
🔒

登录后即可练习

注册免费账号，在浏览器中直接运行 Python 代码