aboutsummaryrefslogtreecommitdiffstats
path: root/kotlinx-coroutines-core/jvm/test/flow/FlatMapStressTest.kt
blob: 699d9c6473cdacb5b5c6d59a05b954fe1ca2cd1a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/*
 * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.internal.*
import kotlinx.coroutines.scheduling.*
import org.junit.Assume.*
import org.junit.Test
import java.util.concurrent.atomic.*
import kotlin.test.*

/**
 * Stress tests for [flatMapMerge] running on [Dispatchers.Default].
 *
 * Most tests build a merged flow whose elements are exactly `1..iterations`, so the collected
 * sum must equal [expectedSum].
 */
class FlatMapStressTest : TestBase() {

    // Number of elements merged per test, scaled by `stressTestMultiplier`.
    private val iterations = 2000 * stressTestMultiplier

    // Sum of 1..iterations, i.e. n * (n + 1) / 2, computed in Long.
    private val expectedSum = iterations.toLong() * (iterations + 1) / 2

    /** Checks that `flatMapMerge(concurrency = 2)` never has more than two inner flows active at once. */
    @Test
    fun testConcurrencyLevel() = runTest {
        withContext(Dispatchers.Default) {
            testConcurrencyLevel(2)
        }
    }

    /** Same as [testConcurrencyLevel], with a concurrency limit of 3. */
    @Test
    fun testConcurrencyLevel2() = runTest {
        withContext(Dispatchers.Default) {
            testConcurrencyLevel(3)
        }
    }

    /**
     * Checks that a `buffer(bufferSize)` applied after [flatMapMerge] bounds the number of
     * already-emitted but not-yet-consumed elements, and that every element is delivered.
     */
    @Test
    fun testBufferSize() = runTest {
        val bufferSize = 5
        withContext(Dispatchers.Default) {
            val inFlightElements = AtomicLong(0L)
            var result = 0L
            // Each inner flow emits 4 consecutive values starting at 1, 5, 9, ..., so together they
            // cover 1..iterations exactly (requires iterations % 4 == 0, true for 2000 * multiplier).
            (1..iterations step 4).asFlow().flatMapMerge { value ->
                unsafeFlow {
                    repeat(4) {
                        emit(value + it)
                        // Counted only once emit() has returned, so emitters still suspended in
                        // emit() are not included in the in-flight count.
                        inFlightElements.incrementAndGet()
                    }
                }
            }.buffer(bufferSize).collect { value ->
                val inFlight = inFlightElements.get()
                // One element of slack over the buffer size; presumably the element currently
                // being handed to this collector — TODO confirm against buffer semantics.
                assertTrue(inFlight <= bufferSize + 1,
                    "Expected at most ${bufferSize + 1} in-flight elements, but had $inFlight")
                inFlightElements.decrementAndGet()
                result += value
            }

            assertEquals(0L, inFlightElements.get())
            assertEquals(expectedSum, result)
        }
    }

    /** Checks that elements from many concurrently merged inner flows are all delivered (sum matches). */
    @Test
    fun testDelivery() = runTest {
        withContext(Dispatchers.Default) {
            val result = (1L..iterations step 4).asFlow().flatMapMerge { value ->
                unsafeFlow {
                    repeat(4) { emit(value + it) }
                }
            }.longSum()
            assertEquals(expectedSum, result)
        }
    }

    /**
     * Runs [iterations] small, independent [flatMapMerge] pipelines. Each one merges four inner
     * flows that emit their value twice, so every pipeline must sum to 2 * (1 + 2 + 3 + 4) = 20.
     */
    @Test
    fun testIndependentShortBursts() = runTest {
        withContext(Dispatchers.Default) {
            repeat(iterations) {
                val result = (1L..4L).asFlow().flatMapMerge { value ->
                    unsafeFlow {
                        emit(value)
                        emit(value)
                    }
                }.longSum()
                assertEquals(20L, result)
            }
        }
    }

    /**
     * Merges `1..iterations` through [flatMapMerge] with the given [maxConcurrency] and asserts
     * that the number of simultaneously running inner-flow bodies never exceeds it, and that all
     * elements are delivered.
     *
     * The test is skipped (JUnit `assumeTrue`) when [maxConcurrency] is greater than `CORE_POOL_SIZE`.
     */
    private suspend fun testConcurrencyLevel(maxConcurrency: Int) {
        assumeTrue(maxConcurrency <= CORE_POOL_SIZE)
        val concurrency = AtomicLong()
        val result = (1L..iterations).asFlow().flatMapMerge(concurrency = maxConcurrency) { value ->
            unsafeFlow {
                // The counter stays raised for the whole inner-flow body, including the emit.
                val current = concurrency.incrementAndGet()
                assertTrue(current in 1..maxConcurrency,
                    "Expected between 1 and $maxConcurrency active inner flows, but had $current")
                emit(value)
                concurrency.decrementAndGet()
            }
        }.longSum()

        assertEquals(0L, concurrency.get())
        assertEquals(expectedSum, result)
    }
}