5
import numpy as np, matplotlib.pyplot as plt
from collections import Counter
# Draw 100 uniform samples in [0, 1): the first half is used for training and
# the second half is classified as the test set.
data = np.random.rand(100)
train, test = np.split(data, 2)
# Training labels follow a fixed threshold rule: values up to and including
# 0.5 belong to Class1, everything above belongs to Class2.
labels = ['Class1' if value <= 0.5 else 'Class2' for value in train]
def knn(x, k, train_x=None, train_y=None):
    """Classify ``x`` by majority vote among its ``k`` nearest training points.

    Distance is the absolute difference ``abs(x - xi)`` (1-D data). Pairs are
    sorted as ``(distance, label)`` tuples, so equal distances are ordered by
    label. On a tied vote the label encountered first wins, which is the label
    of the closest neighbour because the candidates are distance-sorted.

    Args:
        x: The query value.
        k: Number of neighbours to consult; must be >= 1. A ``k`` larger than
            the training set simply uses every training point.
        train_x: Training values; defaults to the module-level ``train``.
        train_y: Labels aligned with ``train_x``; defaults to the module-level
            ``labels``.

    Returns:
        The winning label.

    Raises:
        ValueError: If ``k < 1`` or there are no training points.
    """
    if train_x is None:
        train_x = train
    if train_y is None:
        train_y = labels
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    nearest = sorted((abs(x - xi), li) for xi, li in zip(train_x, train_y))[:k]
    if not nearest:
        raise ValueError("knn needs at least one training point")
    # Counter.most_common orders equal counts by first insertion, so a vote tie
    # resolves to the nearest neighbour's label.
    return Counter(label for _, label in nearest).most_common(1)[0][0]
# For each neighbourhood size, classify every test point, print the
# predictions, and plot training points (row y=0) against predicted test
# points (row y=1). Sizes are taken from the data rather than hard-coded, so
# this loop keeps working if the train/test split above changes.
k_vals = [1, 2, 3, 4, 5, 20, 30]
for k in k_vals:
    print(f"\nResults for k = {k}:")
    preds = [knn(x, k) for x in test]
    # Test points come after the training points in `data`; number them
    # 1-based from there (x51..x100 for a 50/50 split).
    for idx, (value, pred) in enumerate(zip(test, preds), start=len(train) + 1):
        print(f"x{idx} (value: {value:.4f}) → {pred}")
    plt.figure()
    plt.scatter(train, [0] * len(train),
                c=['blue' if lbl == 'Class1' else 'red' for lbl in labels],
                label='Train')
    for cls, color in (('Class1', 'blue'), ('Class2', 'red')):
        pts = [value for value, pred in zip(test, preds) if pred == cls]
        plt.scatter(pts, [1] * len(pts), c=color, marker='x', label=cls)
    plt.title(f"k = {k}")
    plt.yticks([0, 1], ['Train', 'Test'])
    plt.legend()
    plt.grid(True)
    plt.show()
# Comments
# Post a Comment