# 5

import numpy as np, matplotlib.pyplot as plt

from collections import Counter

# Synthetic 1-D dataset: 100 uniform samples in [0, 1). No seed is set, so
# every run draws a different set of points.
data = np.random.rand(100)

# The first half is the labelled training set; the second half gets classified.
train, test = np.split(data, 2)

# Ground-truth rule for the training half: values up to and including 0.5
# are 'Class1', larger values are 'Class2'.
labels = [('Class1' if value <= 0.5 else 'Class2') for value in train]

def knn(x, k, train_x=None, train_y=None):
    """Predict the label of scalar ``x`` by a k-nearest-neighbour majority vote.

    Args:
        x: Query value.
        k: Number of nearest training points that vote. A ``k`` larger than
            the training set simply lets every training point vote.
        train_x: Training values; defaults to the module-level ``train``.
        train_y: Labels aligned with ``train_x``; defaults to the
            module-level ``labels``.

    Returns:
        The most common label among the ``k`` nearest training points.
        Exact distance ties are ordered alphabetically by label. Vote ties
        (possible for even ``k``) go to the tied class whose member is
        nearest to ``x``, because ``Counter`` keeps first-seen order.

    Raises:
        ValueError: If ``k`` is less than 1 or the training set is empty.
    """
    if train_x is None:
        train_x = train
    if train_y is None:
        train_y = labels
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    # Sorting (distance, label) tuples orders points by distance and breaks
    # exact distance ties by label, so results are deterministic.
    neighbours = sorted((abs(x - xi), li) for xi, li in zip(train_x, train_y))
    if not neighbours:
        raise ValueError("training set is empty")

    votes = Counter(label for _, label in neighbours[:k])
    return votes.most_common(1)[0][0]

# Classify every test point for several neighbourhood sizes, print each
# prediction and plot the predictions against the labelled training points.
# NOTE: even k values (2, 4, 20, 30) can produce tied votes; how ties are
# resolved is decided inside knn().
k_vals = [1, 2, 3, 4, 5, 20, 30]

# Test points are numbered after the training points (x51, x52, ... for a
# 50/50 split); derive the offset from the data instead of hard-coding it.
first_test_index = len(train) + 1
# Training colours never change between iterations, so compute them once.
train_colors = ['blue' if l == 'Class1' else 'red' for l in labels]

for k in k_vals:
    print(f"\nResults for k = {k}:")
    preds = [knn(x, k) for x in test]

    for i, (x, p) in enumerate(zip(test, preds), first_test_index):
        print(f"x{i} (value: {x:.4f}) → {p}")

    plt.figure()
    # Training points on the y=0 row, coloured by their true class.
    plt.scatter(train, [0] * len(train), c=train_colors, label='Train')
    # Test points on the y=1 row, one scatter call per predicted class.
    for cls, color in (('Class1', 'blue'), ('Class2', 'red')):
        xs = [x for x, p in zip(test, preds) if p == cls]
        plt.scatter(xs, [1] * len(xs), c=color, marker='x', label=cls)
    plt.title(f"k = {k}")
    plt.yticks([0, 1], ['Train', 'Test'])
    plt.legend()
    plt.grid(True)
    plt.show()

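# NOTE: the lines below are leftover text from the web page this script was
# copied from; they are kept as comments so the file remains valid Python.
# Comments
#
# Popular posts from this blog
#
# 9
#
# 7
#
# 10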