1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
|
/*
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package kotlinx.coroutines.channels
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.selects.*
import org.junit.After
import org.junit.Test
import org.junit.runner.*
import org.junit.runners.*
import kotlin.random.Random
import kotlin.test.*
/**
 * Tests resource transfer via channel send & receive operations, including their select versions,
 * using `onUndeliveredElement` to detect lost resources and close them properly.
 *
 * A sender and a receiver coroutine run on a two-thread pool and are cancelled and restarted at random.
 * Every element is a [Data] instance; the callback passed to `kind.create` invokes [Data.failedToDeliver]
 * for it. At the end of each iteration [verify] asserts that
 * `sentCnt - failedToDeliverCnt == receivedCnt` and that `dupCnt` is still zero.
 *
 * The test is parameterized over every [TestChannelKind] whose `viaBroadcast` flag is false.
 */
@RunWith(Parameterized::class)
class ChannelUndeliveredElementStressTest(private val kind: TestChannelKind) : TestBase() {
    companion object {
        /** Name of this test, used for the worker thread pool and in the progress output. */
        private const val TEST_NAME = "ChannelUndeliveredElementStressTest"

        /** Produces one single-element parameter array per [TestChannelKind] with `viaBroadcast == false`. */
        @Parameterized.Parameters(name = "{0}")
        @JvmStatic
        fun params(): Collection<Array<Any>> =
            TestChannelKind.values()
                .filter { !it.viaBroadcast }
                .map { arrayOf<Any>(it) }
    }

    /** Wall-clock length of one iteration; [verify] runs once per iteration. */
    private val iterationDurationMs = 100L

    /** Number of iterations to run: 20 iterations of 100 ms (2 sec) scaled by `stressTestMultiplier`. */
    private val testIterations = 20 * stressTestMultiplier // 2 sec

    /** Two-thread pool on which the sender and the receiver coroutines are launched; closed in [tearDown]. */
    private val dispatcher = newFixedThreadPoolContext(2, TEST_NAME)
    private val scope = CoroutineScope(dispatcher)

    /** Channel under test; its callback marks each reported element via [Data.failedToDeliver]. */
    private val channel = kind.create<Data> { it.failedToDeliver() }

    /** Receives `true` each time a sender coroutine leaves [cancellable]; awaited by [stopSender]. */
    private val senderDone = Channel<Boolean>(1)

    /** Receives `true` each time a receiver coroutine leaves [cancellable]; awaited by [stopReceiver]. */
    private val receiverDone = Channel<Boolean>(1)

    /**
     * Value of the most recently received element.
     * Written by the receiver coroutine and read by the sender's LINKED_LIST throttling loop, hence volatile.
     */
    @Volatile
    private var lastReceived = -1L

    private var stoppedSender = 0L
    private var stoppedReceiver = 0L

    private var sentCnt = 0L // total number of send attempts
    private var receivedCnt = 0L // actually received successfully
    private var dupCnt = 0L // duplicates (should never happen)
    private val failedToDeliverCnt = atomic(0L) // out of sent

    /** Size of every [ItemStatus] table; element values are mapped to slots with [mask]. */
    private val modulo = 1 shl 25
    private val mask = (modulo - 1).toLong()

    private val sentStatus = ItemStatus() // 1 - send norm, 2 - send select, +2 - did not throw exception
    private val receivedStatus = ItemStatus() // 1-6 received
    private val failedStatus = ItemStatus() // 1 - failed

    /** Currently running (or most recently launched) sender job; assigned by [launchSender]. */
    lateinit var sender: Job

    /** Currently running (or most recently launched) receiver job; assigned by [launchReceiver]. */
    lateinit var receiver: Job

    /** Closes the thread pool created for this test instance. */
    @After
    fun tearDown() {
        dispatcher.close()
    }

    /**
     * Runs [block] and then, whether it completed or threw, signals completion by sending `true` to [done].
     *
     * If [done] does not accept the element, the failure is reported through `error(...)`.
     * The message is passed as a plain string, matching every other `error(...)` call in this class,
     * instead of wrapping an exception object inside the message.
     */
    private inline fun cancellable(done: Channel<Boolean>, block: () -> Unit) {
        try {
            block()
        } finally {
            if (!done.trySend(true).isSuccess)
                error("failed to offer to done channel")
        }
    }

    /**
     * Main stress loop.
     *
     * Starts a sender and a receiver, then loops until `hasError()` reports a failure or
     * [testIterations] iterations have elapsed. Whenever the current iteration's deadline has passed,
     * the state is checked by [verify] (which stops both coroutines) and, unless this was the last
     * iteration, both coroutines are launched again. On every pass of the loop one of three actions is
     * picked at random: restart the sender, restart the receiver, or just yield.
     */
    @Test
    fun testAtomicCancelStress() = runBlocking {
        println("=== $TEST_NAME $kind")
        var nextIterationTime = System.currentTimeMillis() + iterationDurationMs
        var iteration = 0
        launchSender()
        launchReceiver()
        while (!hasError()) {
            if (System.currentTimeMillis() >= nextIterationTime) {
                nextIterationTime += iterationDurationMs
                iteration++
                verify(iteration)
                if (iteration % 10 == 0) printProgressSummary(iteration)
                if (iteration >= testIterations) break
                // verify() stopped both coroutines, so bring them back for the next iteration
                launchSender()
                launchReceiver()
            }
            when (Random.nextInt(3)) {
                0 -> { // cancel & restart sender
                    stopSender()
                    launchSender()
                }
                1 -> { // cancel & restart receiver
                    stopReceiver()
                    launchReceiver()
                }
                2 -> yield() // just yield (burn a little time)
            }
        }
    }

    /**
     * Quiesces the test and checks its invariants for the given [iteration].
     *
     * The sender is stopped first, then the method waits until the channel reports empty, then the
     * receiver is stopped. After that `dupCnt` must be zero and every send attempt must be accounted
     * for as either received or failed to deliver. On an assertion failure the summary and per-value
     * details are printed before the failure is rethrown. On success the per-value status tables are
     * cleared.
     */
    private suspend fun verify(iteration: Int) {
        stopSender()
        drainReceiver()
        stopReceiver()
        try {
            assertEquals(0, dupCnt)
            assertEquals(sentCnt - failedToDeliverCnt.value, receivedCnt)
        } catch (e: Throwable) {
            printProgressSummary(iteration)
            printErrorDetails()
            throw e
        }
        sentStatus.clear()
        receivedStatus.clear()
        failedStatus.clear()
    }

    /** Prints the cumulative counters of the test together with the current [iteration] number. */
    private fun printProgressSummary(iteration: Int) {
        println("--- $TEST_NAME $kind -- $iteration of $testIterations")
        println(" Sent $sentCnt times to channel")
        println(" Received $receivedCnt times from channel")
        println(" Failed to deliver ${failedToDeliverCnt.value} times")
        println(" Stopped sender $stoppedSender times")
        println(" Stopped receiver $stoppedReceiver times")
        println(" Duplicated $dupCnt deliveries")
    }

    /**
     * Prints every value in the recorded range whose per-value bookkeeping is inconsistent, i.e. for which
     * "was sent" minus "failed to deliver" does not equal "was received".
     *
     * The scanned range is the union of the ranges recorded by the three status tables; when nothing was
     * recorded the range is empty and nothing is printed.
     */
    private fun printErrorDetails() {
        val min = minOf(sentStatus.min, receivedStatus.min, failedStatus.min)
        val max = maxOf(sentStatus.max, receivedStatus.max, failedStatus.max)
        for (x in min..max) {
            // Locals are named differently from the class-level counters to avoid shadowing them.
            val wasSent = if (sentStatus[x] != 0) 1 else 0
            val wasReceived = if (receivedStatus[x] != 0) 1 else 0
            val failedCount = failedStatus[x]
            if (wasSent - failedCount != wasReceived) {
                println("!!! Error for value $x: " +
                    "sentStatus=${sentStatus[x]}, " +
                    "receivedStatus=${receivedStatus[x]}, " +
                    "failedStatus=${failedStatus[x]}"
                )
            }
        }
    }

    /**
     * Launches a new sender coroutine and stores it in [sender].
     *
     * The coroutine loops forever, each time creating a [Data] with the next `sentCnt` value and sending it
     * either with `send` (mode 1) or with `select { onSend }` (mode 2). The mode is recorded in
     * [sentStatus] before the send, and `mode + 2` is recorded once the send returned without throwing.
     */
    private fun launchSender() {
        sender = scope.launch(start = CoroutineStart.ATOMIC) {
            cancellable(senderDone) {
                var counter = 0
                while (true) {
                    val trySendData = Data(sentCnt++)
                    val sendMode = Random.nextInt(2) + 1
                    sentStatus[trySendData.x] = sendMode
                    when (sendMode) {
                        1 -> channel.send(trySendData)
                        2 -> select<Unit> { channel.onSend(trySendData) {} }
                        else -> error("cannot happen")
                    }
                    sentStatus[trySendData.x] = sendMode + 2
                    when {
                        // must artificially slow down LINKED_LIST sender to avoid overwhelming receiver and going OOM
                        kind == TestChannelKind.LINKED_LIST -> while (sentCnt > lastReceived + 100) yield()
                        // yield periodically to check cancellation on conflated channels
                        kind.isConflated -> if (counter++ % 100 == 0) yield()
                    }
                }
            }
        }
    }

    /** Cancels the current [sender] and suspends until it has signalled completion via [senderDone]. */
    private suspend fun stopSender() {
        stoppedSender++
        sender.cancel()
        senderDone.receive()
    }

    /**
     * Launches a new receiver coroutine and stores it in [receiver].
     *
     * The coroutine loops forever, each time receiving one element using one of six randomly chosen
     * receive styles. For every received element it bumps `receivedCnt`, counts it in `dupCnt` when its
     * value is not greater than [lastReceived], updates [lastReceived] and records the receive mode in
     * [receivedStatus].
     */
    private fun launchReceiver() {
        receiver = scope.launch(start = CoroutineStart.ATOMIC) {
            cancellable(receiverDone) {
                while (true) {
                    val receiveMode = Random.nextInt(6) + 1
                    val receivedData = when (receiveMode) {
                        1 -> channel.receive()
                        2 -> select { channel.onReceive { it } }
                        3 -> channel.receiveCatching().getOrElse { error("Should not be closed") }
                        4 -> select { channel.onReceiveCatching { it.getOrElse { error("Should not be closed") } } }
                        5 -> channel.receiveCatching().getOrThrow()
                        6 -> {
                            val iterator = channel.iterator()
                            check(iterator.hasNext()) { "Should not be closed" }
                            iterator.next()
                        }
                        else -> error("cannot happen")
                    }
                    receivedCnt++
                    val received = receivedData.x
                    if (received <= lastReceived)
                        dupCnt++
                    lastReceived = received
                    receivedStatus[received] = receiveMode
                }
            }
        }
    }

    /** Yields repeatedly until the channel reports that it is empty. */
    private suspend fun drainReceiver() {
        while (!channel.isEmpty) yield() // burn time until receiver gets it all
    }

    /** Cancels the current [receiver] and suspends until it has signalled completion via [receiverDone]. */
    private suspend fun stopReceiver() {
        stoppedReceiver++
        receiver.cancel()
        receiverDone.receive()
    }

    /**
     * Element transferred through the channel; [x] is its sequence number taken from `sentCnt`.
     */
    private inner class Data(val x: Long) {
        private val failedToDeliver = atomic(false)

        /**
         * Marks this element as undelivered.
         *
         * Fails with "onUndeliveredElement notified twice" if it was already marked; otherwise increments
         * the shared failed-to-deliver counter and records the failure in [failedStatus].
         */
        fun failedToDeliver() {
            check(failedToDeliver.compareAndSet(false, true)) { "onUndeliveredElement notified twice" }
            failedToDeliverCnt.incrementAndGet()
            failedStatus[x] = 1
        }
    }

    /**
     * Byte-per-value status table indexed by `x and mask`, which also tracks the smallest and the largest
     * `x` that has been stored since the last [clear].
     */
    inner class ItemStatus {
        private val a = ByteArray(modulo)
        private val _min = atomic(Long.MAX_VALUE)
        private val _max = atomic(-1L)

        /** Smallest value stored since the last [clear], or `Long.MAX_VALUE` if none. */
        val min: Long get() = _min.value

        /** Largest value stored since the last [clear], or `-1` if none. */
        val max: Long get() = _max.value

        /** Stores [value] (truncated to a byte) for [x] and widens the tracked min/max range to include [x]. */
        operator fun set(x: Long, value: Int) {
            a[(x and mask).toInt()] = value.toByte()
            _min.update { y -> minOf(x, y) }
            _max.update { y -> maxOf(x, y) }
        }

        /** Returns the status stored in the slot that [x] maps to. */
        operator fun get(x: Long): Int = a[(x and mask).toInt()].toInt()

        /** Zeroes the slots of the tracked range and resets the range; does nothing if nothing was stored. */
        fun clear() {
            if (_max.value < 0) return
            for (x in _min.value.._max.value) a[(x and mask).toInt()] = 0
            _min.value = Long.MAX_VALUE
            _max.value = -1L
        }
    }
}
|