初级
第11章 条件随机场 - 实现CRF模型类
未完成
初级参考
完整示例代码供参考，建议理解后自己重新输入。
import numpy as np
class LinearChainCRF:
    """Linear-chain conditional random field over per-position feature vectors.

    For a label sequence ``y`` of length ``T`` and feature vectors ``x``, the
    unnormalized log-score (log-potential) is::

        start[y[0]]
        + sum_t      emission[y[t]] . x[t]
        + sum_{t>=1} transition[y[t-1], y[t]]
        + end[y[T-1]]

    All weights are initialized from a standard normal distribution.

    Attributes:
        num_states: Number of distinct labels; labels are ints in
            ``[0, num_states)``.
        num_features: Length of each per-position feature vector.
        transition_weights: ``(num_states, num_states)`` array; entry
            ``[i, j]`` scores moving from label ``i`` to label ``j``.
        emission_weights: ``(num_states, num_features)`` array; row ``s`` is
            dotted with a position's feature vector to score label ``s``.
        start_weights: ``(num_states,)`` score of the first label.
        end_weights: ``(num_states,)`` score of the last label.
    """

    def __init__(self, num_states, num_features, rng=None):
        """Create a CRF with randomly initialized weights.

        Args:
            num_states: Number of labels.
            num_features: Dimension of each feature vector.
            rng: Optional ``numpy.random.Generator``. When given, weights are
                drawn from ``rng.standard_normal`` so initialization is
                reproducible. When None (the default), the global
                ``np.random`` state is used.
        """
        self.num_states = num_states
        self.num_features = num_features

        def normal(*shape):
            # Standard-normal draws from either the caller's Generator or the
            # global NumPy RNG.
            if rng is None:
                return np.random.randn(*shape)
            return rng.standard_normal(shape)

        # Keep this draw order: with a seeded RNG it fixes which values each
        # weight array receives.
        self.transition_weights = normal(num_states, num_states)
        self.emission_weights = normal(num_states, num_features)
        self.start_weights = normal(num_states)
        self.end_weights = normal(num_states)

    def compute_log_potential(self, features, labels):
        """Return the unnormalized log-score of ``labels`` given ``features``.

        Args:
            features: Array-like of shape ``(T, num_features)``. A flat
                length-``T`` sequence is also accepted when
                ``num_features == 1``.
            labels: Sequence of ``T`` integer labels in ``[0, num_states)``.

        Returns:
            The score as a Python float.

        Raises:
            ValueError: If ``labels`` is empty, if ``len(features)`` differs
                from ``len(labels)``, or if the feature dimension does not
                match ``num_features``.
        """
        seq_len = len(labels)
        if seq_len == 0:
            raise ValueError("labels must contain at least one label")
        if len(features) != seq_len:
            raise ValueError(
                f"features has {len(features)} positions but labels has {seq_len}"
            )
        labels = np.asarray(labels, dtype=int)
        features = np.asarray(features, dtype=float).reshape(seq_len, -1)

        # emissions[t, s] = emission_weights[s] . features[t]; matmul rejects
        # a feature dimension that does not match num_features.
        emissions = features @ self.emission_weights.T
        score = (
            self.start_weights[labels[0]]
            + np.sum(emissions[np.arange(seq_len), labels])
            # Pairs (y[t-1], y[t]) for t >= 1; empty (sum 0) when seq_len == 1.
            + np.sum(self.transition_weights[labels[:-1], labels[1:]])
            + self.end_weights[labels[-1]]
        )
        return float(score)

    def viterbi_decode(self, features):
        """Return the label sequence with the highest log-potential.

        Scores paths exactly as ``compute_log_potential`` does, using the
        Viterbi dynamic program in ``O(T * num_states**2)`` time.

        Args:
            features: Array-like of shape ``(T, num_features)``. A flat
                length-``T`` sequence is also accepted when
                ``num_features == 1``.

        Returns:
            A list of ``T`` Python ints. An empty list if ``features`` is
            empty.
        """
        seq_len = len(features)
        if seq_len == 0:
            return []
        features = np.asarray(features, dtype=float).reshape(seq_len, -1)

        # emissions[t, s] = emission_weights[s] . features[t]
        emissions = features @ self.emission_weights.T
        backpointer = np.zeros((seq_len, self.num_states), dtype=int)

        # score[s]: best log-score of any prefix ending in label s at step t.
        score = self.start_weights + emissions[0]
        for t in range(1, seq_len):
            # candidates[p, s]: best prefix ending in p, then transition p -> s.
            candidates = score[:, None] + self.transition_weights
            backpointer[t] = np.argmax(candidates, axis=0)
            score = np.max(candidates, axis=0) + emissions[t]

        # Backtrack from the best final label. Appending then reversing once
        # avoids repeated O(T) list.insert(0, ...) calls.
        best_path = [int(np.argmax(score + self.end_weights))]
        for t in range(seq_len - 1, 0, -1):
            best_path.append(int(backpointer[t, best_path[-1]]))
        best_path.reverse()
        return best_path
👑
升级 VIP
解锁全部题目，畅通无阻地学习
- ✓ 解锁全部训练包的所有题目
- ✓ 查看完整参考代码和提示
- ✓ 浏览器内直接运行 Python 代码
- ✓ 自动批改 + 进度追踪
30天
¥18
1年
¥99
2年
¥158
3年
¥199