Java多线程高并发系列:(七)手写多线程工具类

2024-10-31 12:59
310
0

一、用AQS手写信号量

package demo.aqs;

import java.util.concurrent.locks.AbstractQueuedSynchronizer;

/**
 * 自定义的信号量类,基于AQS实现。
 */
public class MySemaphore {
    private final Sync sync;

    /**
     * 构造函数,初始化信号量的许可数量。
     *
     * @param permits 初始的许可数量,必须大于等于0
     */
    public MySemaphore(int permits) {
        if (permits < 0) throw new IllegalArgumentException("许可数量不能小于0");
        this.sync = new Sync(permits);
    }

    /**
     * 获取一个许可。如果没有可用的许可,线程将被阻塞,直到有许可可用。
     *
     * @throws InterruptedException 如果线程在等待许可时被中断
     */
    public void acquire() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    /**
     * 释放一个许可,增加可用的许可数量。
     */
    public void release() {
        sync.releaseShared(1);
    }

    /**
     * 返回当前可用的许可数量。
     *
     * @return 当前可用的许可数量
     */
    public int availablePermits() {
        return sync.getPermits();
    }

    /**
     * 内部同步器类,继承自 AbstractQueuedSynchronizer。
     */
    private final class Sync extends AbstractQueuedSynchronizer {
        /**
         * 构造函数,设置初始的许可数量。
         *
         * @param permits 初始的许可数量
         */
        Sync(int permits) {
            setState(permits);
        }

        /**
         * 返回当前的许可数量。
         *
         * @return 当前的许可数量
         */
        int getPermits() {
            return getState();
        }

        /**
         * 尝试获取一个许可。
         *
         * @param acquires 需要获取的许可数量
         * @return 如果成功获取许可,返回非负值;否则返回负值
         */
        protected int tryAcquireShared(int acquires) {
            for (;;) {
                int available = getState(); // 当前可用的许可数量
                int remaining = available - acquires; // 剩余的许可数量
                if (remaining < 0 || compareAndSetState(available, remaining)) {
                    return remaining;
                }
            }
        }

        /**
         * 尝试释放一个许可。
         *
         * @param releases 需要释放的许可数量
         * @return 如果成功释放许可,返回 true;否则返回 false
         */
        protected boolean tryReleaseShared(int releases) {
            for (;;) {
                int current = getState(); // 当前的许可数量
                int next = current + releases; // 新的许可数量
                if (next < current) { // 溢出检查
                    throw new Error("最大许可数量超出限制");
                }
                if (compareAndSetState(current, next)) {
                    return true;
                }
            }
        }
    }
    
    public static void main(String[] args) {
        MySemaphore semaphore = new MySemaphore(1); // 初始化3个许可

        Runnable task = () -> {
            try {
                semaphore.acquire();
                System.out.println(Thread.currentThread().getName() + "获得一个许可,数量:"+ semaphore.availablePermits());
                Thread.sleep(1000); // 模拟任务执行时间
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                semaphore.release();
                System.out.println(Thread.currentThread().getName() + "释放一个许可,数量:"+ semaphore.availablePermits());
            }
        };

        for (int i = 0; i < 10; i++) {
            new Thread(task).start();
        }
    }
}

二、用AQS手写CyclicBarrier

package demo.aqs;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

/**
 * 基于 AQS 实现的 CyclicBarrier。
 */
public class MyCyclicBarrier {

    /**
     * 同步器,继承自 AbstractQueuedSynchronizer。
     */
    private static class Sync extends AbstractQueuedSynchronizer {
        private final int count; // 需要等待的线程数

        Sync(int count) {
            this.count = count;
            setState(count); // 初始化状态为需要等待的线程数
        }

        /**
         * 尝试获取共享资源的方法
         * 该方法主要用于处理共享锁的获取逻辑,在同步器中用于处理读写锁的读锁获取
         *
         * @param acquires 表示要获取的资源数量,在此场景下通常为1,代表一个读锁
         * @return 返回值表示获取资源的结果如果返回值大于0,表示成功获取读锁并且之后没有线程需要等待;
         *         如果返回值等于0,表示成功获取读锁但之后可能有其他线程需要等待;
         *         如果返回值小于0,表示获取读锁失败
         */
        @Override
        protected int tryAcquireShared(int acquires) {
            for (;;) {
                int readers = getState();
                if (readers == 0) { // 如果状态为0,表示所有线程已经到达屏障
                    return -1; // 返回-1表示失败
                }
                int nextReaders = readers - acquires;
                if (compareAndSetState(readers, nextReaders)) { // CAS操作,减少状态值
                    return nextReaders == 0 ? 1 : 0; // 如果状态值减到0,返回1表示成功,否则返回0
                }
            }
        }

        /**
         * 尝试释放共享锁
         * 此方法主要用于处理共享锁的释放逻辑,特别是在同步组件(如CountDownLatch)中
         * @param releases 这里应为1,因为只有一个线程会执行释放操作,但释放的数量可能会影响状态值
         * @return 如果释放后状态值为0,表示所有线程已经到达同步屏障,则返回true;否则返回false
         */
        @Override
        protected boolean tryReleaseShared(int releases) {
            return getState() == 0; // 返回false表示失败
        }

        void reset() {
            setState(count); // 重置状态为需要等待的线程数
        }
    }

    private final Sync sync;

    /**
     * 构造函数,指定需要等待的线程数。
     *
     * @param parties 需要等待的线程数
     */
    public MyCyclicBarrier(int parties) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.sync = new Sync(parties);
    }

    /**
     * 等待所有线程到达屏障。
     *
     * @throws InterruptedException 如果当前线程被中断
     */
    public void await() throws InterruptedException {
        if (Thread.interrupted()) throw new InterruptedException();
        int acquireResult = sync.tryAcquireShared(1);
        if (acquireResult == 0) {
            synchronized (this) {
                wait(); // 使用对象的监视器进行等待
            }
        } else if (acquireResult > 0 && sync.tryReleaseShared(1)) { // 尝试释放共享锁
            synchronized (this) {
                notifyAll(); // 使用对象的监视器进行通知
                System.out.println("所有线程都到达了!");
                sync.reset(); // 重置同步器状态
            }
        }
        if (acquireResult < 0) { // 尝试获取共享锁
            throw new InterruptedException();
        }
    }

    public static void main(String[] args) {
        int numberOfParties = 3; // 设置需要等待的线程数
        MyCyclicBarrier barrier = new MyCyclicBarrier(numberOfParties);
        ExecutorService executor = Executors.newFixedThreadPool(numberOfParties);

        for (int i = 0; i < numberOfParties; i++) {
            executor.submit(() -> {
                try {
                    System.out.println(Thread.currentThread().getName() + "开始等待其它线程");
                    barrier.await();
                    System.out.println(Thread.currentThread().getName() + "继续执行");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            });
        }

        executor.shutdown();
    }
}

三、用AQS手写CountDownLatch

package demo.aqs;

import java.util.concurrent.locks.AbstractQueuedSynchronizer;

/**
 * 基于 AQS 实现的 MyCountDownLatch 类。
 * MyCountDownLatch 允许一个或多个线程等待其他线程完成操作。
 */
public class MyCountDownLatch {

    private final Sync sync;

    /**
     * 构造函数,初始化计数器。
     * @param count 初始计数器值,必须大于 0
     */
    public MyCountDownLatch(int count) {
        if (count <= 0) {
            throw new IllegalArgumentException("Count must be positive");
        }
        this.sync = new Sync(count);
    }

    /**
     * 等待直到计数器归零。
     * 如果计数器已经为零,则此方法立即返回。
     * 否则,当前线程将被阻塞,直到计数器归零。
     * @throws InterruptedException 如果当前线程在等待过程中被中断
     */
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    /**
     * 减少计数器的值。
     * 如果计数器归零,则所有等待的线程将被释放。
     */
    public void countDown() {
        sync.releaseShared(1);
    }

    /**
     * 获取当前计数器的值。
     * @return 当前计数器的值
     */
    public long getCount() {
        return sync.getCount();
    }

    /**
     * 内部同步器类,继承自 AbstractQueuedSynchronizer。
     */
    private static final class Sync extends AbstractQueuedSynchronizer {

        /**
         * 构造函数,初始化计数器。
         * @param count 初始计数器值
         */
        Sync(int count) {
            setState(count); // 设置初始计数器值
        }

        /**
         * 尝试获取共享锁。
         * 如果当前计数器值大于 0,则返回负值表示获取失败。
         * 如果当前计数器值为 0,则返回 0 表示获取成功。
         * @param acquires 传入的参数,这里没有实际用途
         * @return 获取结果
         */
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        /**
         * 尝试释放共享锁。
         * 减少计数器的值,如果计数器归零,则返回 true 表示释放成功。
         * @param releases 传入的参数,这里没有实际用途
         * @return 释放结果
         */
        protected boolean tryReleaseShared(int releases) {
            // 计算新的计数器值
            for (;;) {
                int current = getState();
                int next = current - 1;
                if (next < 0) // 计数器值不能小于 0
                    return false;
                if (compareAndSetState(current, next)) // CAS 操作更新计数器值
                    return next == 0; // 如果计数器归零,返回 true
            }
        }

        /**
         * 获取当前计数器的值。
         * @return 当前计数器的值
         */
        int getCount() {
            return getState();
        }
    }

    /**
     * 主函数,演示如何使用 MyCountDownLatch。
     */
    public static void main(String[] args) {
        final int numThreads = 3;
        MyCountDownLatch latch = new MyCountDownLatch(numThreads);

        // 创建并启动多个线程
        for (int i = 0; i < numThreads; i++) {
            new Thread(() -> {
                System.out.println(Thread.currentThread().getName() + " is working...");
                try {
                    Thread.sleep((long) (Math.random() * 1000)); // 模拟工作时间
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(Thread.currentThread().getName() + " has finished work.");
                latch.countDown(); // 工作完成后减少计数器
            }).start();
        }

        // 主线程等待所有子线程完成工作
        try {
            latch.await();
            System.out.println("All threads have completed their work.");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

 

全部评论