只看原创   查看文章

数据结构学习(一)—— 线性表

package per.neal.chapter01;

/**
 * Contract for a linear table (an ordered, index-addressable sequence of elements).
 *
 * @param <T> element type
 * @author neal
 */
public interface LinearTable<T> {

    /**
     * Reports whether the table is empty.
     *
     * @return boolean
     */
    boolean isEmpty();

    /**
     * Removes every element from the table.
     */
    void clear();

    /**
     * Returns the element stored at an index.
     *
     * @param i index
     * @return the element at that index
     */
    T getElement(int i);

    /**
     * Returns the length of the table.
     *
     * @return int
     */
    int length();

    /**
     * Inserts an element at the given position.
     *
     * @param i       index
     * @param element element to insert
     */
    void insert(int i, T element);

    /**
     * Removes the element at the given position and returns it.
     *
     * @param i index
     * @return the removed element
     */
    T delete(int i);

    /**
     * Replaces the element at the given position.
     *
     * @param i       index
     * @param element new element
     */
    void setElement(int i, T element);

    /**
     * Looks up an element in the table.
     *
     * @param element element to look for
     * @return an element of type {@code T} (not a boolean, despite the method name);
     *         NOTE(review): presumably the stored match, with some "absent" value
     *         otherwise - confirm against the implementation
     */
    T haveElement(T element);

    /**
     * Appends an element at the end of the table.
     *
     * @param element element to append
     */
    void add(T element);
}


package per.neal.chapter01;


/**
 * Array-backed linear table that grows on demand.
 *
 * <p>{@code null} elements are never stored: {@link #insert}, {@link #add} and
 * {@link #setElement} silently ignore a {@code null} argument.
 *
 * @param <T> element type
 * @author neal
 */
public class LinearTableImpl<T> implements LinearTable<T> {
    /**
     * Backing array; only the slots {@code [0, length)} hold live elements.
     */
    private T[] data;

    /**
     * Number of live elements.
     */
    private int length;

    /**
     * Creates a table with the given initial capacity.
     *
     * @param size initial capacity
     */
    @SuppressWarnings("unchecked")
    public LinearTableImpl(int size) {
        this.data = (T[]) new Object[size];
        this.length = 0;
    }

    /**
     * Creates a table with the default initial capacity of 16.
     */
    public LinearTableImpl() {
        this(16);
    }

    @Override
    public boolean isEmpty() {
        return this.length == 0;
    }

    @Override
    public void clear() {
        // Drop the references so the removed elements can be garbage collected.
        for (int i = 0; i < this.length; i++) {
            this.data[i] = null;
        }
        this.length = 0;
    }

    /**
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, length)}
     */
    @Override
    public T getElement(int i) {
        checkElementIndex(i);
        return this.data[i];
    }

    @Override
    public int length() {
        return this.length;
    }

    /**
     * Inserts {@code element} at position {@code i}, shifting later elements right.
     * A {@code null} element is ignored.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, length]}
     */
    @Override
    public void insert(int i, T element) {
        if (i > this.length || i < 0) {
            throw new IndexOutOfBoundsException(i + "越界了");
        }
        if (element == null) {
            return;
        }
        if (this.length == this.data.length) {
            grow();
        }
        System.arraycopy(this.data, i, this.data, i + 1, this.length - i);
        this.data[i] = element;
        this.length++;
    }

    /**
     * Removes and returns the element at position {@code i}, shifting later elements left.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, length)}
     */
    @Override
    public T delete(int i) {
        checkElementIndex(i);
        T element = this.data[i];
        System.arraycopy(this.data, i + 1, this.data, i, this.length - 1 - i);
        // Clear the vacated last slot so it does not keep a stale reference alive.
        this.data[--this.length] = null;
        return element;
    }

    /**
     * Replaces the element at position {@code i}. A {@code null} element is ignored.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, length)}
     */
    @Override
    public void setElement(int i, T element) {
        checkElementIndex(i);
        if (element == null) {
            return;
        }
        this.data[i] = element;
    }

    /**
     * Returns the stored element equal to {@code element}, or {@code null} when there is none.
     */
    @Override
    public T haveElement(T element) {
        int find = this.indexOf(element);
        return find == -1 ? null : this.data[find];
    }

    @Override
    public void add(T element) {
        this.insert(this.length, element);
    }

    /**
     * Returns the index of the first element equal to {@code element}, or -1.
     */
    private int indexOf(T element) {
        if (element != null) {
            for (int i = 0; i < this.length; i++) {
                if (this.data[i].equals(element)) {
                    return i;
                }
            }
        }
        return -1;
    }

    /**
     * Rejects any index that does not address a live element.
     */
    private void checkElementIndex(int i) {
        if (i < 0 || i >= this.length) {
            throw new IndexOutOfBoundsException(i + "越界了");
        }
    }

    /**
     * Enlarges the backing array by roughly 1.5x, and always by at least one slot
     * so that capacities 0 and 1 can grow too.
     */
    @SuppressWarnings("unchecked")
    private void grow() {
        int oldCapacity = this.data.length;
        int newCapacity = Math.max(oldCapacity + 1, oldCapacity + (oldCapacity >> 1));
        T[] grown = (T[]) new Object[newCapacity];
        System.arraycopy(this.data, 0, grown, 0, this.length);
        this.data = grown;
    }

    /**
     * Renders the live elements as {@code (a,b,c,)}; unused capacity is not printed.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder("(");
        for (int i = 0; i < this.length; i++) {
            builder.append(this.data[i]).append(",");
        }
        builder.append(")");
        return builder.toString();
    }
}


原创 Mar 4, 2018 9:12:36 AM 60 0

数据结构学习(二)—— 单链表

/**
 * Singly linked list that tracks both its head and its tail node.
 *
 * @param <T> element type
 * @author Neal
 */
public class LinkedListCus<T> {
    private class Node {

        /**
         * Payload of this node.
         */
        private T item;
        /**
         * Reference to the following node, or null for the last node.
         */
        private Node next;

        Node(T item, Node next) {
            this.item = item;
            this.next = next;
        }
    }

    /**
     * First node, or null when the list is empty.
     */
    private Node head;
    /**
     * Last node, or null when the list is empty.
     */
    private Node tail;
    /**
     * Number of nodes in the list.
     */
    private int size;

    /**
     * Creates an empty list.
     */
    public LinkedListCus() {
        this.head = null;
        this.tail = null;
    }

    /**
     * Creates a list holding the single element {@code data}.
     *
     * @param data first element
     */
    public LinkedListCus(T data) {
        this.head = new Node(data, null);
        tail = head;
        size++;
    }

    public int length() {
        return size;
    }

    /**
     * Returns the element at {@code index}.
     *
     * @throws IndexOutOfBoundsException if {@code index} is not in {@code [0, size)}
     */
    public T getElement(int index) {
        return findNodeByIndex(index).item;
    }

    /**
     * Walks from the head to the node at {@code index}; never returns null.
     */
    private Node findNodeByIndex(int index) {
        if (index < 0 || index > size - 1) {
            throw new IndexOutOfBoundsException("线性表越界");
        }
        Node current = head;
        for (int i = 0; i < index; i++) {
            current = current.next;
        }
        return current;
    }

    /**
     * Returns the index of the first element equal to {@code elements}, or -1 when absent.
     * Works on an empty list and also inspects the last node.
     */
    public int getIndexByElement(T elements) {
        int i = 0;
        for (Node current = head; current != null; current = current.next, i++) {
            boolean same = current.item == null ? elements == null : current.item.equals(elements);
            if (same) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Inserts {@code element} so that it ends up at position {@code index}.
     *
     * @throws IndexOutOfBoundsException if {@code index} is not in {@code [0, size]}
     */
    public void insert(int index, T element) {
        if (index < 0 || index > size) {
            throw new IndexOutOfBoundsException("线性表越界");
        }
        if (index == size) {
            // Covers both the empty list and appending; add() maintains tail and size.
            add(element);
        } else if (index == 0) {
            addAtHead(element);
        } else {
            // Link the new node between the predecessor and its old successor.
            Node prev = findNodeByIndex(index - 1);
            prev.next = new Node(element, prev.next);
            size++;
        }
    }

    /**
     * Appends a node at the tail.
     *
     * @param element element
     */
    public void add(T element) {
        Node node = new Node(element, null);
        if (head == null) {
            head = node;
        } else {
            tail.next = node;
        }
        tail = node;
        size++;
    }

    /**
     * Inserts a node in front of the current head.
     *
     * @param element element
     */
    public void addAtHead(T element) {
        head = new Node(element, head);
        // The list was empty before: the new node is also the tail.
        if (tail == null) {
            tail = head;
        }
        size++;
    }

    /**
     * Removes and returns the element at {@code index}.
     *
     * @throws IndexOutOfBoundsException if {@code index} is not in {@code [0, size)}
     */
    public T delete(int index) {
        if (index < 0 || index > size - 1) {
            throw new IndexOutOfBoundsException("线性表越界");
        }
        Node deleteNode;
        if (index == 0) {
            deleteNode = head;
            head = head.next;
            if (head == null) {
                tail = null;
            }
        } else {
            Node prev = findNodeByIndex(index - 1);
            deleteNode = prev.next;
            prev.next = deleteNode.next;
            // Removing the last node moves the tail back to its predecessor.
            if (deleteNode == tail) {
                tail = prev;
            }
        }
        deleteNode.next = null;
        size--;
        return deleteNode.item;
    }

    /**
     * Removes the last element of the list.
     *
     * @return T
     */
    public T removeLast() {
        return delete(size - 1);
    }

    public void clear() {
        head = null;
        tail = null;
        size = 0;
    }

    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * Renders the list as {@code [a,b,c]}, or {@code []} when empty.
     *
     * @return the elements in order
     */
    @Override
    public String toString() {
        if (isEmpty()) {
            return "[]";
        }
        StringBuilder sb = new StringBuilder("[");
        for (Node current = head; current != null; current = current.next) {
            // append(Object) tolerates null items.
            sb.append(current.item).append(",");
        }
        int len = sb.length();
        return sb.delete(len - 1, len).append("]").toString();
    }
}


原创 Mar 19, 2018 8:17:20 PM 49 0

数据结构学习(三)—— 双链表(双向链表)

/**
 * Doubly linked list.
 *
 * <p>Positions accepted by {@link #removeByIndex(int)} are 1-based: position 1 is the
 * head and position {@code length()} is the tail, which is how {@link #remove()} has
 * always addressed the last element.
 *
 * @param <T> element type
 * @author Neal
 */
public class LinkedListBySelf<T> {
    private class Node {
        private T data;
        private Node next;
        private Node prev;

        Node(T data, Node prev, Node next) {
            this.data = data;
            this.prev = prev;
            this.next = next;
        }
    }

    private Node head;
    private Node tail;
    private int size;

    public LinkedListBySelf() {
        this.head = null;
        this.tail = null;
        this.size = 0;
    }

    /**
     * Appends {@code element} at the tail.
     */
    public void add(T element) {
        if (head == null) {
            head = new Node(element, null, null);
            tail = head;
        } else {
            Node last = new Node(element, tail, null);
            tail.next = last;
            tail = last;
        }
        size++;
    }

    public int length() {
        return size;
    }

    /**
     * Removes and returns the last element.
     *
     * @throws IndexOutOfBoundsException if the list is empty
     */
    public T remove() {
        // Test the size: head == tail is also true for a single-element list.
        if (size == 0) {
            throw new IndexOutOfBoundsException("这是个空链表,无法删除");
        }
        return removeByIndex(size);
    }

    /**
     * Removes and returns the element at the 1-based position {@code index}.
     *
     * @param index position in {@code [1, length()]}
     * @throws IndexOutOfBoundsException if {@code index} is outside that range
     */
    public T removeByIndex(int index) {
        if (index > size || index < 1) {
            throw new IndexOutOfBoundsException("越界了");
        }
        Node delNode = loopThis(index);
        Node before = delNode.prev;
        Node after = delNode.next;
        // Bridge the neighbours around the removed node, moving head/tail at the ends.
        if (before == null) {
            head = after;
        } else {
            before.next = after;
        }
        if (after == null) {
            tail = before;
        } else {
            after.prev = before;
        }
        delNode.prev = null;
        delNode.next = null;
        size--;
        return delNode.data;
    }

    /**
     * Walks backwards from the tail to the node at the 1-based position {@code index}.
     */
    private Node loopThis(int index) {
        if (size == 0) {
            throw new IndexOutOfBoundsException("空链表");
        }
        Node current = tail;
        for (int i = size; i > index; i--) {
            current = current.prev;
        }
        return current;
    }
}


原创 Mar 19, 2018 8:18:59 PM 53 0

数据结构学习(四)—— 栈的顺序存储结构

/**
 * Fixed-capacity stack backed by an array (sequential storage).
 *
 * @param <T> element type
 * @author Neal
 */
public class OrderStack<T> {
    private Node node;

    /**
     * Creates a stack that can hold at most {@code size} elements.
     *
     * @param size capacity
     */
    public OrderStack(int size) {
        this.node = new Node(size);
    }

    private class Node {
        T[] data;
        /**
         * Index of the top element; -1 when the stack is empty.
         */
        int top;

        @SuppressWarnings("unchecked")
        Node(int size) {
            this.data = (T[]) new Object[size];
            this.top = -1;
        }
    }

    /**
     * Pushes {@code element} onto the stack.
     *
     * @throws IndexOutOfBoundsException if the stack is full
     */
    public void push(T element) {
        if (node.top == node.data.length - 1) {
            throw new IndexOutOfBoundsException("栈满了");
        }
        node.data[++node.top] = element;
    }

    /**
     * Removes and returns the top element.
     *
     * @throws IndexOutOfBoundsException if the stack is empty
     */
    public T pop() {
        if (node.top == -1) {
            throw new IndexOutOfBoundsException("栈是空的");
        }
        T element = node.data[node.top];
        // Clear the slot so the popped element is not kept alive by the array.
        node.data[node.top--] = null;
        return element;
    }

    /**
     * Returns the number of elements on the stack.
     * {@code top} is the index of the top element, so the count is {@code top + 1}.
     */
    public int length() {
        return node.top + 1;
    }
}


原创 Mar 19, 2018 8:20:44 PM 42 0

数据结构学习(五)—— 两栈共享空间栈(未测试)

/**
 * Two stacks sharing one array: stack 1 grows upwards from index 0 and stack 2 grows
 * downwards from the last index. The structure is full when the two tops meet.
 *
 * @param <T> element type
 * @author Neal
 */
public class ShareOrderStack<T> {
    private Node node;

    /**
     * Creates the shared storage with room for {@code size} elements in total.
     *
     * @param size combined capacity of both stacks
     */
    public ShareOrderStack(int size) {
        node = new Node(size);
    }

    private class Node {
        /**
         * Shared storage.
         */
        private T[] data;
        /**
         * Index of the top of stack 1; -1 when stack 1 is empty.
         */
        private int top1;
        /**
         * Index of the top of stack 2; {@code data.length} when stack 2 is empty.
         */
        private int top2;

        @SuppressWarnings("unchecked")
        Node(int size) {
            this.data = (T[]) new Object[size];
            this.top1 = -1;
            this.top2 = size;
        }
    }

    /**
     * Pushes {@code element} onto the chosen stack.
     *
     * @param element element to push
     * @param choose  1 for stack 1, 2 for stack 2
     * @throws IndexOutOfBoundsException if the shared storage is full
     * @throws IllegalArgumentException  if {@code choose} is neither 1 nor 2
     */
    public void push(T element, int choose) throws Exception {
        if (node.top1 + 1 == node.top2) {
            throw new IndexOutOfBoundsException("栈已经满了");
        }
        checkChoice(choose);
        if (choose == 1) {
            node.data[++node.top1] = element;
        } else {
            // Stack 2 grows towards lower indexes, so its own top moves down.
            node.data[--node.top2] = element;
        }
    }

    /**
     * Removes and returns the top element of the chosen stack.
     *
     * @param choose 1 for stack 1, 2 for stack 2
     * @return the popped element
     * @throws IndexOutOfBoundsException if the chosen stack is empty
     * @throws IllegalArgumentException  if {@code choose} is neither 1 nor 2
     */
    public T pop(int choose) throws Exception {
        checkChoice(choose);
        if (choose == 1) {
            if (node.top1 == -1) {
                throw new IndexOutOfBoundsException("栈1是空的");
            }
            T element = node.data[node.top1];
            node.data[node.top1--] = null;
            return element;
        }
        if (node.top2 == node.data.length) {
            throw new IndexOutOfBoundsException("栈2是空的");
        }
        T element = node.data[node.top2];
        // Popping stack 2 moves its top back up towards the end of the array.
        node.data[node.top2++] = null;
        return element;
    }

    /**
     * Rejects any stack selector other than 1 or 2 (0 used to be silently ignored).
     */
    private static void checkChoice(int choose) {
        if (choose != 1 && choose != 2) {
            throw new IllegalArgumentException("没有这个选择");
        }
    }
}


原创 Mar 19, 2018 8:21:46 PM 39 0

数据结构学习(六)—— 链栈

/**
 * Stack implemented as a singly linked chain of nodes; the head node is the top.
 *
 * @param <T> element type
 * @author Neal
 */
public class LinkedStack<T> {
    /**
     * Top-of-stack node (null when empty) and the element counter.
     */
    private Node head;
    private int size;

    public LinkedStack() {
        this.size = 0;
        this.head = null;
    }

    /**
     * Returns the number of elements currently on the stack.
     */
    public int length() {
        return size;
    }

    /**
     * Puts {@code element} on top of the stack.
     */
    public void push(T element) {
        Node pushed = new Node(element, head);
        head = pushed;
        size++;
    }

    /**
     * Takes the top element off the stack and returns it.
     *
     * @throws IndexOutOfBoundsException if the stack is empty
     */
    public T remove() {
        if (size == 0) {
            throw new IndexOutOfBoundsException("空栈无法执行删除操作");
        }
        Node popped = head;
        T value = popped.data;
        head = popped.next;
        // Detach the removed node from the chain.
        popped.next = null;
        size--;
        return value;
    }

    private class Node {
        private T data;
        private Node next;

        Node(T data, Node next) {
            this.data = data;
            this.next = next;
        }
    }
}


原创 Mar 19, 2018 8:22:42 PM 38 0

机器学习(1)-- java实现KNN

首先得弄个bean存储数据,因为Java无法直接返回两个数组

/**
 * Plain holder that lets one method hand back a sample matrix together with its labels.
 * Arrays are stored and returned by reference; no copies are made.
 *
 * @author Neal
 */
public class DataSetBean {
    private String[] labels;
    private int[][] dataSet;

    /**
     * @return the label array last passed to {@link #setLabels(String[])}
     */
    public String[] getLabels() {
        return this.labels;
    }

    /**
     * @param labels label array to store
     */
    public void setLabels(String[] labels) {
        this.labels = labels;
    }

    /**
     * @return the sample matrix last passed to {@link #setDataSet(int[][])}
     */
    public int[][] getDataSet() {
        return this.dataSet;
    }

    /**
     * @param dataSet sample matrix to store
     */
    public void setDataSet(int[][] dataSet) {
        this.dataSet = dataSet;
    }
}

接下来就是具体实现了,KNN的具体原理请自行百度,这里就不再详讲了。看代码

import java.util.*;

/**
 * K-nearest-neighbours (KNN) classifier in plain Java.
 * The demo rates a (height, weight) sample against a toy data set - a joke example,
 * not to be taken seriously.
 *
 * @author Neal
 */
public class KnnCus {

    /**
     * Builds the training samples and their labels and wraps both in a bean.
     *
     * @return DataSetBean
     */
    private DataSetBean createDataSet() {
        DataSetBean bean = new DataSetBean();
        int[][] dataSet = {
                {171, 85}, {160, 107}, {158, 106}, {158, 96}, {160, 87},
                {162, 84}, {165, 88}, {170, 96}, {167, 90}, {168, 94}
        };
        String[] labels = {"1分", "2分", "3分", "4分", "5分", "6分", "7分", "8分", "9分", "10分"};
        bean.setDataSet(dataSet);
        bean.setLabels(labels);
        return bean;
    }

    /**
     * KNN classifier based on Euclidean distance.
     *
     * @param inX     sample to classify; only the first row {@code inX[0]} is used
     * @param dataSet training samples, one row per sample, same width as {@code inX[0]}
     * @param labels  label of each training row
     * @param k       number of nearest neighbours that vote
     * @return the label with the most votes; on a tie, the label of the nearer neighbour
     * @throws IllegalArgumentException if {@code k} is not in {@code [1, dataSet.length]}
     */
    private String classify(int[][] inX, int[][] dataSet, String[] labels, int k) {
        int dataSetSize = dataSet.length;
        if (k < 1 || k > dataSetSize) {
            throw new IllegalArgumentException(
                    "k must be between 1 and " + dataSetSize + ", got " + k);
        }
        int[] sample = inX[0];
        double[] distance = new double[dataSetSize];
        for (int i = 0; i < dataSetSize; i++) {
            // Accumulate in double so large coordinate differences cannot overflow int.
            double sumOfSquares = 0.0;
            for (int j = 0; j < sample.length; j++) {
                double diff = (double) sample[j] - dataSet[i][j];
                sumOfSquares += diff * diff;
            }
            distance[i] = Math.sqrt(sumOfSquares);
        }
        int[] sortedByIndex = sortedByIndex(distance);
        // LinkedHashMap keeps the labels in nearest-first order, which breaks vote ties.
        Map<String, Integer> classCount = new LinkedHashMap<>();
        for (int i = 0; i < k; i++) {
            classCount.merge(labels[sortedByIndex[i]], 1, Integer::sum);
        }
        Map<String, Integer> ranked = sortMapByValue(classCount);
        return ranked.keySet().iterator().next();
    }

    /**
     * Returns the indexes of {@code old} ordered by ascending value (an "argsort").
     * The sort is stable, so equal values keep their original relative order and every
     * index appears exactly once.
     *
     * @param old values to rank
     * @return indexes into {@code old}, smallest value first
     */
    private int[] sortedByIndex(double[] old) {
        Integer[] order = new Integer[old.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> Double.compare(old[a], old[b]));
        int[] newIndex = new int[old.length];
        for (int i = 0; i < newIndex.length; i++) {
            newIndex[i] = order[i];
        }
        return newIndex;
    }

    /**
     * Returns a copy of the map ordered by descending value; entries with equal values
     * keep the iteration order of the input map.
     *
     * @param oriMap map to sort; may be null or empty
     * @return a new insertion-ordered map (empty when the input is null or empty)
     */
    private Map<String, Integer> sortMapByValue(Map<String, Integer> oriMap) {
        Map<String, Integer> sortedMap = new LinkedHashMap<>();
        if (oriMap == null || oriMap.isEmpty()) {
            return sortedMap;
        }
        List<Map.Entry<String, Integer>> entryList = new ArrayList<>(oriMap.entrySet());
        // Integer.compare avoids the overflow risk of subtracting the two counts.
        entryList.sort((o1, o2) -> Integer.compare(o2.getValue(), o1.getValue()));
        for (Map.Entry<String, Integer> entry : entryList) {
            sortedMap.put(entry.getKey(), entry.getValue());
        }
        return sortedMap;
    }

    public static void main(String[] args) {
        KnnCus knn = new KnnCus();
        DataSetBean bean = knn.createDataSet();
        // Sample to classify: height 165, weight 95.
        int[][] inX = {{165, 95}};
        int k = 3;
        String result = knn.classify(inX, bean.getDataSet(), bean.getLabels(), k);
        System.out.println(result);
    }
}


原创 Mar 22, 2018 12:18:36 PM 46 0

机器学习(1)-- python实现KNN

#!/usr/bin/python3
# coding:utf-8
# Filename:test_girl.py
# Author:Neal
# Time:2018.03.21 17:09

import numpy as np
import operator

"""
测试该女性属于哪一身材阶层(纯属搞笑,切勿当真)
"""


def create_data_set():
    """
    Build the toy training set.

    :return: tuple ``(data_set, label_set)``: a 10x2 numpy array of samples
             and the list holding the label of each row
    """
    samples = [
        [171, 85], [160, 107], [158, 106], [158, 96], [160, 87],
        [162, 84], [165, 88], [170, 96], [167, 90], [168, 94],
    ]
    # Labels run from '1分' to '10分', one per sample row.
    label_set = ['%d分' % score for score in range(1, 11)]
    return np.array(samples), label_set


def classify(in_x, data_set, labels, k):
    """
    Classify one sample with the k-nearest-neighbours rule.

    :param in_x: sample to classify
    :param data_set: training samples (2-D numpy array, one row per sample)
    :param labels: label of each training row
    :param k: number of nearest neighbours that vote
    :return: the label with the most votes among the k nearest rows
    """
    n_samples = data_set.shape[0]
    # Repeat the sample once per training row, then take Euclidean distances.
    offsets = np.tile(in_x, (n_samples, 1)) - data_set
    distances = (offsets ** 2).sum(axis=1) ** 0.5
    nearest_first = distances.argsort()

    votes = {}
    for rank in range(k):
        label = labels[nearest_first[rank]]
        votes[label] = votes.get(label, 0) + 1

    ranked = sorted(votes.items(), key=operator.itemgetter(1), reverse=True)
    return ranked[0][0]


if __name__ == '__main__':
    # Classify one sample against the toy data set using the 3 nearest neighbours.
    group, labels = create_data_set()
    print(classify([165, 95], group, labels, 3))


原创 Mar 22, 2018 12:22:10 PM 37 0

机器学习(2)-- python决策树

#!/usr/bin/python3
# coding:utf-8
# Filename:np3.py
# Author:Neal
# Time:2018.03.20 17:11
import operator
from math import log
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt


def createDataSet():
    """
    Build the loan-approval training set.

    :return: tuple ``(dataSet, labels)``: the rows (four feature values followed
             by the 'yes'/'no' class) and the names of the four features
    """
    # Feature names, in column order.
    labels = ['年龄', '有工作', '有自己的房子', '信贷情况']
    dataSet = [
        [0, 0, 0, 0, 'no'],
        [0, 0, 0, 1, 'no'],
        [0, 1, 0, 1, 'yes'],
        [0, 1, 1, 0, 'yes'],
        [0, 0, 0, 0, 'no'],
        [1, 0, 0, 0, 'no'],
        [1, 0, 0, 1, 'no'],
        [1, 1, 1, 1, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [2, 0, 1, 2, 'yes'],
        [2, 0, 1, 1, 'yes'],
        [2, 1, 0, 1, 'yes'],
        [2, 1, 0, 2, 'yes'],
        [2, 0, 0, 0, 'no'],
    ]
    return dataSet, labels


def calcShannonEnt(dataSet):
    """
    Compute the empirical (Shannon) entropy of a data set.

    :param dataSet: rows whose last item is the class label
    :return: entropy in bits; 0.0 for an empty data set
    """
    total = len(dataSet)
    # Count how often each class label occurs.
    tally = {}
    for row in dataSet:
        cls = row[-1]
        tally[cls] = tally.get(cls, 0) + 1
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy


def splitDataSet(dataSet, axis, value):
    """
    Select the rows whose feature ``axis`` equals ``value`` and drop that column.

    :param dataSet: rows to filter (each row is a list)
    :param axis: index of the feature to test and remove
    :param value: feature value a row must have to be kept
    :return: new list of new rows without column ``axis``; the input is not modified
    """
    selected = []
    for row in dataSet:
        if row[axis] != value:
            continue
        # Copy everything before the column, then append everything after it.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        selected.append(trimmed)
    return selected


def chooseBestFeatureToSplit(dataSet):
    """
    Pick the feature with the largest information gain.

    Prints the gain of every feature as a side effect.

    :param dataSet: rows whose last item is the class label
    :return: index of the best feature, or -1 when no feature has a positive gain
    """
    totalRows = float(len(dataSet))
    baseEntropy = calcShannonEnt(dataSet)
    bestGain, bestIndex = 0.0, -1
    for col in range(len(dataSet[0]) - 1):
        # Conditional entropy: entropy of each subset weighted by its share of rows.
        condEntropy = 0.0
        for value in set(row[col] for row in dataSet):
            subset = splitDataSet(dataSet, col, value)
            condEntropy += len(subset) / totalRows * calcShannonEnt(subset)
        gain = baseEntropy - condEntropy
        print("第%d个特征的增益为%.3f" % (col, gain))
        if gain > bestGain:
            bestGain, bestIndex = gain, col
    return bestIndex


def majorityCnt(classList):
    """
    Return the class label that occurs most often.

    :param classList: list of class labels
    :return: the most frequent label; on a tie, the one that appeared first
    """
    tally = {}
    for vote in classList:
        tally[vote] = tally.get(vote, 0) + 1
    ranked = sorted(tally.items(), key=lambda pair: pair[1], reverse=True)
    return ranked[0][0]


def createTree(dataSet, labels, featLabels):
    """
    Recursively build a decision tree as nested dicts.

    :param dataSet: rows whose last item is the class label
    :param labels: feature names matching the feature columns of ``dataSet``;
                   this list is no longer modified
    :param featLabels: output list; the name of every feature chosen for a split
                       is appended to it in the order the splits are made
    :return: a class label for a leaf, otherwise ``{featureName: {value: subtree}}``
    """
    classList = [example[-1] for example in dataSet]  # class label of every row
    if classList.count(classList[0]) == len(classList):  # all rows agree: leaf
        return classList[0]
    if len(dataSet[0]) == 1:  # no features left: fall back to the majority class
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat == -1:
        # No feature separates the classes any further; indexing with -1 would
        # split on the class column itself, so return the majority class instead.
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}
    # Give the subtrees their own label list. Deleting from the shared list let one
    # branch remove labels that its sibling branches (and the caller) still needed.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in set(example[bestFeat] for example in dataSet):
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels, featLabels)
    return myTree


def getNumLeafs(myTree):
    """
    Count the leaves of a decision tree.

    :param myTree: tree of the form ``{featureName: {value: subtree_or_leaf}}``
    :return: number of leaf nodes
    """
    # dict.keys() is not indexable in Python 3, so take the single root key via iter().
    rootLabel = next(iter(myTree))
    leafTotal = 0
    for child in myTree[rootLabel].values():
        # A dict child is a subtree; anything else is a leaf.
        if type(child).__name__ == 'dict':
            leafTotal += getNumLeafs(child)
        else:
            leafTotal += 1
    return leafTotal


def getTreeDepth(myTree):
    """
    Compute the number of decision levels in a tree.

    :param myTree: tree of the form ``{featureName: {value: subtree_or_leaf}}``
    :return: depth of the deepest branch (a tree whose children are all leaves has depth 1)
    """
    rootLabel = next(iter(myTree))
    children = myTree[rootLabel].values()
    # A leaf child contributes one level; a subtree contributes one plus its own depth.
    return max(
        (1 + getTreeDepth(child) if type(child).__name__ == 'dict' else 1
         for child in children),
        default=0,
    )


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    Draw one annotated node with an arrow between it and its parent.

    Draws on ``createPlot.ax1``, so ``createPlot`` must have set up the axes first.

    :param nodeTxt: text shown inside the node
    :param centerPt: position of the node text (axes-fraction coordinates)
    :param parentPt: position of the parent, used as the annotated point
    :param nodeType: dict of box style options passed as ``bbox``
    :return: None
    """
    arrow_args = dict(arrowstyle='<-')  # arrow format
    # NOTE(review): hard-coded Windows font path; presumably chosen so the Chinese
    # labels render - confirm on other platforms.
    font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)
    # 'fontproperties' is the documented Text property name; the capitalised
    # 'FontProperties' spelling only worked through case-insensitive matching.
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt,
                            textcoords='axes fraction', va='center', ha='center',
                            bbox=nodeType, arrowprops=arrow_args, fontproperties=font)


def plotMidText(cntrPt, parentPt, txtString):
    """
    Write the edge label halfway between a node and its parent.

    Draws on ``createPlot.ax1``.

    :param cntrPt: position of the child node
    :param parentPt: position of the parent node
    :param txtString: text to place on the edge
    :return: None
    """
    childX, childY = cntrPt[0], cntrPt[1]
    # Midpoint = child position plus half of the offset towards the parent.
    midX = childX + (parentPt[0] - childX) / 2.0
    midY = childY + (parentPt[1] - childY) / 2.0
    createPlot.ax1.text(midX, midY, txtString, va='center', ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """
    Recursively draw a decision tree.

    Layout state is kept in function attributes (plotTree.xOff, plotTree.yOff,
    plotTree.totalW, plotTree.totalD), which createPlot initialises before the
    first call.

    :param myTree: decision tree (nested dicts)
    :param parentPt: position of the parent node
    :param nodeTxt: text written on the edge leading to this subtree's root
    :return: None
    """
    decisionNode = dict(boxstyle='sawtooth', fc="0.8")  # box style of decision nodes
    leafNode = dict(boxstyle='round4', fc="0.8")  # box style of leaf nodes
    numLeafs = getNumLeafs(myTree)  # leaf count of this subtree, used to centre its root
    depth = getTreeDepth(myTree)  # depth of this subtree (not used below)
    firstStr = next(iter(myTree))  # root key of this subtree
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)  # centre position
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the edge coming from the parent
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # draw the decision node
    secondDict = myTree[firstStr]  # children of this node
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # move one level down
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))  # not a leaf: recurse
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW  # leaf: advance x, draw it and label its edge
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the y offset so the caller continues at its own level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """
    Open a figure, set up the layout state used by plotTree and draw the tree.

    :param inTree: decision tree (nested dicts)
    :return: None
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # One frameless subplot without tick marks; plotNode/plotMidText draw on it.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    leafCount = float(getNumLeafs(inTree))
    plotTree.totalW = leafCount
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / leafCount
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), "")
    plt.show()

def classsify(inputTree, featLabels, testVec):
    """
    Classify a sample by walking the decision tree.

    :param inputTree: decision tree (nested dicts)
    :param featLabels: feature names; the position of a name gives the index of
                       that feature's value in ``testVec``
    :param testVec: feature values of the sample
    :return: the class label reached, or None when no branch matches
    """
    rootLabel = next(iter(inputTree))
    branches = inputTree[rootLabel]
    position = featLabels.index(rootLabel)
    outcome = None
    for branchValue, subtree in branches.items():
        if testVec[position] == branchValue:
            if type(subtree).__name__ == 'dict':
                outcome = classsify(subtree, featLabels, testVec)
            else:
                outcome = subtree
    return outcome


if __name__ == '__main__':
    dataSet, labels = createDataSet()
    featLabels = []
    myTree = createTree(dataSet, labels, featLabels)
    # print(myTree)
    # createPlot(myTree)
    # The test vector lists the values of the features in featLabels order.
    result = classsify(myTree, featLabels, [0, 1])
    if result == 'yes':
        print('放贷')
    elif result == 'no':
        print('不放贷')


原创 Mar 22, 2018 12:24:08 PM 51 0

机器学习(3) -- python朴素贝叶斯

#!/usr/bin/python3
# coding:utf-8
# Filename:np7.py
# Author:Neal
# Time:2018.03.21 15:46
import numpy as np
from functools import reduce


def loadDataSet():
    """
    Create the sample posts and their class labels.

    :return: tuple ``(postingList, classVec)``: the tokenised posts and, for each
             post, 1 (abusive) or 0 (not abusive)
    """
    sentences = [
        'my dog has flea problems help please',
        'maybe not take him to dog park stupid',
        'my dalmation is so cute I love him',
        'stop posting stupid worthless garbage',
        'mr licks ate my steak how to stop him',
        'quit buying worthless dog food stupid',
    ]
    # Each post is stored as its list of whitespace-separated tokens.
    postingList = [sentence.split(' ') for sentence in sentences]
    classVec = [0, 1, 0, 1, 0, 1]
    return postingList, classVec


def setOfWords2Vec(vocabList, inputSet):
    """
    Turn a tokenised document into a 0/1 vector over the vocabulary (set-of-words model).

    Prints a notice for every token that is not in the vocabulary.

    :param vocabList: vocabulary list, e.g. the result of createVocabList
    :param inputSet: tokens of the document
    :return: list with one entry per vocabulary word: 1 if the word occurs, else 0
    """
    vector = [0 for _ in vocabList]
    for token in inputSet:
        if token not in vocabList:
            print("the word:%s is not in my Vocabulary!" % token)
            continue
        vector[vocabList.index(token)] = 1
    return vector


def createVocabList(dataSet):
    """
    Collect the distinct tokens of all documents into a vocabulary.

    :param dataSet: list of tokenised documents
    :return: list of the distinct tokens (in no particular order)
    """
    # The union over every document removes duplicates.
    return list(set().union(*dataSet))


def trainNBO(trainMatrix, trainCategory):
    """
    Train the naive Bayes classifier.

    :param trainMatrix: one word vector per document (as built by setOfWords2Vec)
    :param trainCategory: class of each document, 1 (abusive) or 0 (not abusive)
    :return: tuple ``(p0Vect, p1Vect, pAbusive)``: per-word conditional
             probabilities for class 0 and class 1, and the share of class-1 documents
    """
    numTrainDocs = len(trainMatrix)
    numWords = len(trainMatrix[0])
    pAbusive = sum(trainCategory) / float(numTrainDocs)
    # Per-class word counts (numerators) and total word counts (denominators).
    wordCounts = {0: np.zeros(numWords), 1: np.zeros(numWords)}
    wordTotals = {0: 0.0, 1: 0.0}
    for docIndex, wordVector in enumerate(trainMatrix):
        cls = 1 if trainCategory[docIndex] == 1 else 0
        wordCounts[cls] += wordVector
        wordTotals[cls] += sum(wordVector)
    p1Vect = wordCounts[1] / wordTotals[1]
    p0Vect = wordCounts[0] / wordTotals[0]
    return p0Vect, p1Vect, pAbusive


def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):
    """
    Classify one document vector with a trained two-class naive Bayes model.

    The score of each class is its prior multiplied by the conditional
    probabilities of the words that are present in the document. Positions
    whose entry in vec2Classify is 0 are skipped; multiplying them in (as an
    element-wise product over the whole vocabulary would) forces both scores
    to 0 whenever any vocabulary word is missing from the document.

    :param vec2Classify: 0/1 word vector of the document to classify
    :param p0Vec: per-word conditional probabilities for class 0 (non-abusive)
    :param p1Vec: per-word conditional probabilities for class 1 (abusive)
    :param pClass1: prior probability that a document is abusive
    :return: 0 - non-abusive
             1 - abusive
    """
    p1 = pClass1
    p0 = 1.0 - pClass1
    for present, wordP0, wordP1 in zip(vec2Classify, p0Vec, p1Vec):
        if present:  # only words that occur in the document contribute
            p0 *= wordP0
            p1 *= wordP1
    print('p0:', p0)
    print('p1:', p1)
    if p1 > p0:
        return 1
    else:
        return 0


def testingNB():
    """
    End-to-end demo: train on the sample posts, then classify two test entries.

    Loads the sample data, builds the vocabulary, vectorizes every post,
    trains the model and prints the predicted class of each test entry.
    """
    posts, labels = loadDataSet()
    vocabulary = createVocabList(posts)
    trainMat = [setOfWords2Vec(vocabulary, post) for post in posts]
    p0V, p1V, pAb = trainNBO(np.array(trainMat), np.array(labels))

    for testEntry in (['love', 'my', 'dalmation'], ['stupid', 'garbage']):
        thisDoc = np.array(setOfWords2Vec(vocabulary, testEntry))
        # classifyNB returns 1 for the abusive class and 0 otherwise.
        verdict = '属于侮辱类' if classifyNB(thisDoc, p0V, p1V, pAb) else '属于非侮辱类'
        print(testEntry, verdict)


if __name__ == '__main__':
    # Script entry point: run the end-to-end classification demo.
    # The commented-out lines below are a manual walkthrough of the same
    # pipeline that prints the intermediate values; kept for reference.
    # postingList, classVec = loadDataSet()
    # myVocabList = createVocabList(postingList)
    # print('myVocabList:\n', myVocabList)
    # trainMat = []
    # for postingDoc in postingList:
    #     trainMat.append(setOfWords2Vec(myVocabList, postingDoc))
    # p0V, p1V, pAb = trainNBO(trainMat, classVec)
    # print('p0V:\n', p0V)
    # print('p1V:\n', p1V)
    # print("classVec:\n", classVec)
    # print("pAb:\n", pAb)
    testingNB()


转载 Mar 22, 2018 12:25:49 PM 49 0

数据结构学习(七)—— 顺序存储队列

/**
 * Fixed-capacity circular FIFO queue backed by an array.
 *
 * <p>One array slot is always left unused so that a full queue can be told apart from an
 * empty one ({@code front == rear} means empty). A queue created with {@code size}
 * therefore holds at most {@code size - 1} elements.
 *
 * @param <T> element type
 * @author Neal
 */
public class LoopQueue<T> {
    /** Backing storage; the slot just before {@code front} (cyclically) is never occupied. */
    private final T[] data;

    /** Index of the oldest element. */
    private int front;

    /** Index of the slot the next element will be written to. */
    private int rear;

    /**
     * Creates an empty queue backed by an array of the given length.
     *
     * @param size length of the backing array; the usable capacity is {@code size - 1}
     * @throws IllegalArgumentException if {@code size} is less than 1
     */
    @SuppressWarnings("unchecked")
    public LoopQueue(int size) {
        if (size < 1) {
            // A zero-length array would make every modulo operation below divide by zero.
            throw new IllegalArgumentException("size must be at least 1, got " + size);
        }
        // Safe: the array never leaves this class and only ever stores T values.
        this.data = (T[]) new Object[size];
        this.front = 0;
        this.rear = 0;
    }

    /**
     * Returns the number of elements currently stored.
     *
     * @return element count, between 0 and {@code size - 1}
     */
    public int queueLength() {
        // Adding the array length keeps the difference non-negative after rear wraps around.
        return (rear - front + data.length) % data.length;
    }

    /**
     * Appends an element at the tail of the queue.
     *
     * @param element element to store
     * @throws IndexOutOfBoundsException if the queue is full
     */
    public void enQueue(T element) {
        if ((rear + 1) % data.length == front) {
            throw new IndexOutOfBoundsException("队列已经满了");
        }
        data[rear] = element;
        rear = (rear + 1) % data.length;
    }

    /**
     * Removes and returns the element at the head of the queue.
     *
     * @return the oldest element
     * @throws IndexOutOfBoundsException if the queue is empty
     */
    public T deQueue() {
        if (front == rear) {
            throw new IndexOutOfBoundsException("队列为空");
        }
        T element = data[front];
        // Drop the stale reference so the removed element can be garbage collected.
        data[front] = null;
        front = (front + 1) % data.length;
        return element;
    }
}

小心坑:循环队列为了区分“队满”和“队空”会空出一个存储单元,因此指定 size 的队列最多只能存放 size-1 个元素。

原创 Mar 24, 2018 11:52:29 AM 57 0

数据结构学习(八)—— 链式存储队列(链表队列)

/**
 * Unbounded FIFO queue backed by a singly linked list.
 *
 * <p>{@code front} points at the oldest node and {@code rear} at the newest; both are
 * {@code null} exactly when the queue is empty.
 *
 * @param <T> element type
 * @author Neal
 */
public class LinkedQueue<T> {
    /** One link of the list: a value plus the reference to the node enqueued after it. */
    private static final class Node<E> {
        private final E data;
        private Node<E> next;

        Node(E data) {
            this.data = data;
        }
    }

    /** Oldest node, or {@code null} when the queue is empty. */
    private Node<T> front;

    /** Newest node, or {@code null} when the queue is empty. */
    private Node<T> rear;

    /** Number of elements currently stored. */
    private int size;

    /** Creates an empty queue. */
    public LinkedQueue() {
        this.front = null;
        this.rear = null;
        this.size = 0;
    }

    /**
     * Appends an element at the tail of the queue.
     *
     * @param element element to store
     */
    public void enQueue(T element) {
        Node<T> added = new Node<T>(element);
        if (front == null) {
            // First element: it is both the head and the tail.
            front = added;
        } else {
            // Link the new node behind the current tail.
            rear.next = added;
        }
        rear = added;
        size++;
    }

    /**
     * Removes and returns the element at the head of the queue.
     *
     * @return the oldest element
     * @throws IndexOutOfBoundsException if the queue is empty
     */
    public T deQueue() {
        // Emptiness is front == null; front == rear is also true for a one-element queue.
        if (front == null) {
            throw new IndexOutOfBoundsException("这是个空队列");
        }
        T element = front.data;
        front = front.next;
        if (front == null) {
            // The last element was removed, so the tail must not keep pointing at it.
            rear = null;
        }
        size--;
        return element;
    }

    /**
     * Returns the number of elements in the queue without modifying it.
     *
     * @return element count
     */
    public int length() {
        return size;
    }
}


原创 Mar 24, 2018 11:55:33 AM 56 0

我的头像

黑天白夜

你懂的越多,懂你的就越少!

  • 来访数:4,422
  • 总文章:28
  • 原创数:27
  • 点赞数:8