← 返回题库
初级

第11章 条件随机场 - 实现CRF模型类

未完成
初级参考：完整示例代码供参考，建议自己理解后重新输入
import numpy as np

class LinearChainCRF:
    """Linear-chain conditional random field with linear emission scores.

    The unnormalised log-score of a label sequence ``y`` given per-position
    feature vectors ``x`` is::

        start[y_0]
        + sum_t   emission[y_t] . x_t
        + sum_t>0 transition[y_{t-1}, y_t]
        + end[y_{T-1}]

    All weights are initialised from a standard normal distribution.
    """

    def __init__(self, num_states, num_features, rng=None):
        """Create a CRF with randomly initialised weights.

        Args:
            num_states: Number of distinct labels.
            num_features: Length of each per-position feature vector.
            rng: Optional ``numpy.random.Generator`` used to draw the initial
                weights (for reproducibility). When ``None`` (the default) the
                global NumPy random state is used, exactly as before.
        """
        self.num_states = num_states
        self.num_features = num_features
        if rng is None:
            def draw(*shape):
                return np.random.randn(*shape)
        else:
            def draw(*shape):
                return rng.standard_normal(shape)
        # Draw order is kept fixed so the global random stream is consumed
        # identically to earlier versions when rng is None.
        # transition_weights[i, j]: score for label i followed by label j.
        self.transition_weights = draw(num_states, num_states)
        # emission_weights[s]: weight vector dotted with the features of a
        # position labelled s.
        self.emission_weights = draw(num_states, num_features)
        self.start_weights = draw(num_states)
        self.end_weights = draw(num_states)

    def compute_log_potential(self, features, labels):
        """Return the unnormalised log-score of ``labels`` given ``features``.

        Args:
            features: Sequence of per-position feature vectors (each of length
                ``num_features``), indexed in step with ``labels``.
            labels: Non-empty sequence of integer labels in
                ``range(num_states)``.

        Returns:
            Scalar score: start + emissions + transitions + end.

        Raises:
            IndexError: If ``labels`` is empty.
        """
        score = self.start_weights[labels[0]]
        for i, label in enumerate(labels):
            score += np.dot(self.emission_weights[label], features[i])
            if i > 0:
                score += self.transition_weights[labels[i - 1], label]
        score += self.end_weights[labels[-1]]
        return score

    def viterbi_decode(self, features):
        """Return the highest-scoring label sequence for ``features``.

        Args:
            features: Sequence (or 2-D array) of per-position feature vectors,
                each of length ``num_features``.

        Returns:
            List of ``int`` labels, one per position; ``[]`` for an empty
            input sequence.
        """
        seq_len = len(features)
        if seq_len == 0:
            return []

        # emissions[t, s]: emission score of label s at position t.
        emissions = np.asarray(features, dtype=float) @ self.emission_weights.T
        backpointer = np.zeros((seq_len, self.num_states), dtype=int)

        # best[s]: best score of any prefix ending in label s at position t.
        best = self.start_weights + emissions[0]
        for t in range(1, seq_len):
            # candidates[i, j]: best prefix ending in i, then moving i -> j.
            candidates = best[:, None] + self.transition_weights
            # argmax over axis 0 picks the first maximal previous state,
            # matching per-column np.argmax tie-breaking.
            backpointer[t] = candidates.argmax(axis=0)
            best = candidates.max(axis=0) + emissions[t]

        # Backtrack by appending and reversing once: O(n) instead of the
        # O(n^2) cost of repeated list.insert(0, ...).
        state = int(np.argmax(best + self.end_weights))
        path = [state]
        for t in range(seq_len - 1, 0, -1):
            state = int(backpointer[t, state])
            path.append(state)
        path.reverse()
        return path
Python 代码 🔒 登录后使用
🔒

登录后即可练习

注册免费账号，在浏览器中直接运行 Python 代码