Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
outputAttr,
stateInfo = None,
batchTimestampMs = None,
prevBatchTimestampMs = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
planLater(child),
Expand Down Expand Up @@ -815,6 +816,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
func, t.leftAttributes, outputAttrs, outputMode, timeMode,
stateInfo = None,
batchTimestampMs = None,
prevBatchTimestampMs = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
userFacingDataType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ case class TransformWithStateInPySparkExec(
timeMode: TimeMode,
stateInfo: Option[StatefulOperatorStateInfo],
batchTimestampMs: Option[Long],
prevBatchTimestampMs: Option[Long] = None,
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
userFacingDataType: TransformWithStateInPySpark.UserFacingDataType.Value,
Expand Down Expand Up @@ -314,7 +315,8 @@ case class TransformWithStateInPySparkExec(
val data = groupAndProject(filteredIter, groupingAttributes, child.output, dedupAttributes)

val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs,
prevBatchTimestampMs, metrics)

val evalType = {
if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) {
Expand Down Expand Up @@ -442,6 +444,7 @@ object TransformWithStateInPySparkExec {
Some(System.currentTimeMillis),
None,
None,
None,
userFacingDataType,
child,
isStreaming = false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ case class TransformWithStateExec(
outputObjAttr: Attribute,
stateInfo: Option[StatefulOperatorStateInfo],
batchTimestampMs: Option[Long],
prevBatchTimestampMs: Option[Long] = None,
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
child: SparkPlan,
Expand Down Expand Up @@ -251,7 +252,7 @@ case class TransformWithStateExec(
case ProcessingTime =>
assert(batchTimestampMs.isDefined)
val batchTimestamp = batchTimestampMs.get
processorHandle.getExpiredTimers(batchTimestamp)
processorHandle.getExpiredTimers(batchTimestamp, prevBatchTimestampMs)
.flatMap { case (keyObj, expiryTimestampMs) =>
numExpiredTimers += 1
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
Expand All @@ -260,7 +261,13 @@ case class TransformWithStateExec(
case EventTime =>
assert(eventTimeWatermarkForEviction.isDefined)
val watermark = eventTimeWatermarkForEviction.get
processorHandle.getExpiredTimers(watermark)
// Only use the late-events watermark as the scan lower bound when a previous batch
// actually existed (prevBatchTimestampMs is set). In the very first batch the
// watermark propagation yields Some(0) for late events even though no timers have
// been processed yet, which would incorrectly skip timers registered at timestamp 0.
val prevWatermark =
if (prevBatchTimestampMs.isDefined) eventTimeWatermarkForLateEvents else None
processorHandle.getExpiredTimers(watermark, prevWatermark)
.flatMap { case (keyObj, expiryTimestampMs) =>
numExpiredTimers += 1
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
Expand Down Expand Up @@ -493,7 +500,7 @@ case class TransformWithStateExec(
CompletionIterator[InternalRow, Iterator[InternalRow]] = {
val processorHandle = new StatefulProcessorHandleImpl(
store, getStateInfo.queryRunId, keyEncoder, timeMode,
isStreaming, batchTimestampMs, metrics)
isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.setHandle(processorHandle)
withStatefulProcessorErrorHandling("init") {
Expand All @@ -509,7 +516,7 @@ case class TransformWithStateExec(
initStateIterator: Iterator[InternalRow]):
CompletionIterator[InternalRow, Iterator[InternalRow]] = {
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
keyEncoder, timeMode, isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.setHandle(processorHandle)
withStatefulProcessorErrorHandling("init") {
Expand Down Expand Up @@ -581,6 +588,7 @@ object TransformWithStateExec {
Some(System.currentTimeMillis),
None,
None,
None,
child,
isStreaming = false,
hasInitialState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class StatefulProcessorHandleImpl(
timeMode: TimeMode,
isStreaming: Boolean = true,
batchTimestampMs: Option[Long] = None,
prevBatchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric] = Map.empty)
extends StatefulProcessorHandleImplBase(timeMode, keyEncoder) with Logging {
import StatefulProcessorHandleState._
Expand Down Expand Up @@ -171,13 +172,19 @@ class StatefulProcessorHandleImpl(

/**
 * Function to retrieve all expired registered timers for all grouping keys.
 *
 * @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive),
 *                          this function will return all timers that have timestamp
 *                          less than or equal to the passed threshold.
 * @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range.
 *                              Timers at or below this timestamp are assumed to have been
 *                              already processed in the previous batch and will be skipped.
 * @return - iterator of registered timers for all grouping keys
 */
def getExpiredTimers(
    expiryTimestampMs: Long,
    prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = {
  // Fail fast if the handle is not in a state where timer reads are allowed.
  verifyTimerOperations("get_expired_timers")
  timerState.getExpiredTimers(expiryTimestampMs, prevExpiryTimestampMs)
}

/**
Expand Down Expand Up @@ -237,7 +244,8 @@ class StatefulProcessorHandleImpl(
validateTTLConfig(ttlConfig, stateName)
assert(batchTimestampMs.isDefined)
val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName,
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get,
prevBatchTimestampMs, metrics)
ttlStates.add(valueStateWithTTL)
TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars")
valueStateWithTTL
Expand Down Expand Up @@ -286,7 +294,8 @@ class StatefulProcessorHandleImpl(
validateTTLConfig(ttlConfig, stateName)
assert(batchTimestampMs.isDefined)
val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get,
prevBatchTimestampMs, metrics)
TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars")
ttlStates.add(listStateWithTTL)
listStateWithTTL
Expand Down Expand Up @@ -324,7 +333,8 @@ class StatefulProcessorHandleImpl(
validateTTLConfig(ttlConfig, stateName)
assert(batchTimestampMs.isDefined)
val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc,
valEncoder, ttlConfig, batchTimestampMs.get, metrics)
valEncoder, ttlConfig, batchTimestampMs.get,
prevBatchTimestampMs, metrics)
TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars")
ttlStates.add(mapStateWithTTL)
mapStateWithTTL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,27 @@ class TimerStateImpl(
schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, Seq(0)),
useMultipleValuesPerKey = false, isInternal = true)

// Reusable projection for encoding a single-timestamp row into an UnsafeRow key
// of the secondary (timestamp-ordered) index.
private val secIndexProjection = UnsafeProjection.create(keySchemaForSecIndex)

/**
 * Encodes a timestamp into an UnsafeRow key for the secondary index.
 * The timestamp is incremented by 1 so that the encoded key serves as an exclusive
 * lower / upper bound in range scans. Returns None if tsMs is Long.MaxValue
 * (overflow guard, since tsMs + 1 would wrap around).
 *
 * The returned UnsafeRow is always a fresh copy, safe to hold alongside other
 * rows produced by the same projection.
 */
private def encodeTimestampAsKey(tsMs: Long): Option[UnsafeRow] = {
  if (tsMs == Long.MaxValue) {
    // tsMs + 1 would overflow; an unbounded scan edge is represented as None.
    None
  } else {
    val boundRow = new GenericInternalRow(keySchemaForSecIndex.length)
    boundRow.setLong(0, tsMs + 1)
    // copy() detaches the result from the projection's shared internal buffer.
    Some(secIndexProjection.apply(boundRow).copy())
  }
}

private def getGroupingKey(cfName: String): Any = {
val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
if (keyOption.isEmpty) {
Expand Down Expand Up @@ -189,15 +210,22 @@ class TimerStateImpl(

/**
* Function to get all the expired registered timers for all grouping keys.
* Perform a range scan on timestamp and will stop iterating once the key row timestamp equals or
* Perform a range scan on timestamp and will stop iterating once the key row timestamp
* exceeds the limit (as timestamp key is increasingly sorted).
* @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function
* will return all timers that have timestamp less than passed threshold.
* @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive),
* this function will return all timers that have timestamp
* less than or equal to the passed threshold.
* @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range.
* Timers at or below this timestamp are assumed to have been
* already processed in the previous batch and will be skipped.
* @return - iterator of all the registered timers for all grouping keys
*/
def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
// this iter is increasingly sorted on timestamp
val iter = store.iterator(tsToKeyCFName)
def getExpiredTimers(
expiryTimestampMs: Long,
prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = {
val startKey = prevExpiryTimestampMs.flatMap(encodeTimestampAsKey)
val endKey = encodeTimestampAsKey(expiryTimestampMs)
val iter = store.rangeScan(startKey, endKey, tsToKeyCFName)

new NextIterator[(Any, Long)] {
override protected def getNext(): (Any, Long) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ import org.apache.spark.util.NextIterator
* @param valEncoder - Spark SQL encoder for value
* @param ttlConfig - TTL configuration for values stored in this state
* @param batchTimestampMs - current batch processing timestamp.
* @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive).
* Entries with expiration at or below this timestamp are assumed
* to have been already cleaned up and will be skipped during
* TTL eviction scans.
* @param metrics - metrics to be updated as part of stateful processing
* @tparam S - data type of object that will be stored
*/
Expand All @@ -45,9 +49,11 @@ class ListStateImplWithTTL[S](
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
prevBatchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric])
extends OneToManyTTLState(
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, metrics) with ListState[S] {
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs,
prevBatchTimestampMs, metrics) with ListState[S] {

private lazy val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder,
stateName, hasTtl = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ import org.apache.spark.util.NextIterator
* @param valEncoder - SQL encoder for state variable
* @param ttlConfig - the ttl configuration (time to live duration etc.)
* @param batchTimestampMs - current batch processing timestamp.
* @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive).
* Entries with expiration at or below this timestamp are assumed
* to have been already cleaned up and will be skipped during
* TTL eviction scans.
* @param metrics - metrics to be updated as part of stateful processing
* @tparam K - type of key for map state variable
* @tparam V - type of value for map state variable
Expand All @@ -49,10 +53,11 @@ class MapStateImplWithTTL[K, V](
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
metrics: Map[String, SQLMetric])
prevBatchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric])
extends OneToOneTTLState(
stateName, store, getCompositeKeySchema(keyExprEnc.schema, userKeyEnc.schema), ttlConfig,
batchTimestampMs, metrics) with MapState[K, V] with Logging {
batchTimestampMs, prevBatchTimestampMs, metrics) with MapState[K, V] with Logging {

private val stateTypesEncoder = new CompositeKeyStateEncoder(
keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ trait TTLState {
// an expiration at or before this timestamp must be cleaned up.
private[sql] def batchTimestampMs: Long

// The batch timestamp from the previous micro-batch, used to derive the startKey
// for scan-based TTL eviction. Entries at or below prevBatchTimestampMs were already
// cleaned up in the previous batch.
private[sql] def prevBatchTimestampMs: Option[Long]

// The configuration for this run of the streaming query. It may change between runs
// (e.g. user sets ttlConfig1, stops their query, updates to ttlConfig2, and then
// resumes their query).
Expand All @@ -105,6 +110,8 @@ trait TTLState {

private final val TTL_ENCODER = new TTLEncoder(elementKeySchema)

private final val ELEMENT_KEY_PROJECTION = UnsafeProjection.create(elementKeySchema)

// Empty row used for values
private final val TTL_EMPTY_VALUE_ROW =
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
Expand Down Expand Up @@ -161,10 +168,25 @@ trait TTLState {
//
// The schema of the UnsafeRow returned by this iterator is (expirationMs, elementKey).
private[sql] def ttlEvictionIterator(): Iterator[UnsafeRow] = {
val ttlIterator = store.iterator(TTL_INDEX)
val dummyElementKey = ELEMENT_KEY_PROJECTION
.apply(new GenericInternalRow(elementKeySchema.length))
val startKey = prevBatchTimestampMs.flatMap { prevTs =>
if (prevTs < Long.MaxValue) {
Some(TTL_ENCODER.encodeTTLRow(prevTs + 1, dummyElementKey).copy())
} else {
None
}
}
val endKey = if (batchTimestampMs < Long.MaxValue) {
Some(TTL_ENCODER.encodeTTLRow(batchTimestampMs + 1, dummyElementKey).copy())
} else {
None
}
val ttlIterator = store.rangeScan(startKey, endKey, TTL_INDEX)

// Recall that the format is (expirationMs, elementKey) -> TTL_EMPTY_VALUE_ROW, so
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Self-review: maybe this code comment is still valid? If so, let's just leave it as it is.

// kv.value doesn't ever need to be used.
// Safety filter: keep only truly expired entries
ttlIterator.takeWhile { kv =>
val expirationMs = kv.key.getLong(0)
StateTTL.isExpired(expirationMs, batchTimestampMs)
Expand Down Expand Up @@ -223,12 +245,14 @@ abstract class OneToOneTTLState(
elementKeySchemaArg: StructType,
ttlConfigArg: TTLConfig,
batchTimestampMsArg: Long,
prevBatchTimestampMsArg: Option[Long],
metricsArg: Map[String, SQLMetric]) extends TTLState {
override private[sql] def stateName: String = stateNameArg
override private[sql] def store: StateStore = storeArg
override private[sql] def elementKeySchema: StructType = elementKeySchemaArg
override private[sql] def ttlConfig: TTLConfig = ttlConfigArg
override private[sql] def batchTimestampMs: Long = batchTimestampMsArg
override private[sql] def prevBatchTimestampMs: Option[Long] = prevBatchTimestampMsArg
override private[sql] def metrics: Map[String, SQLMetric] = metricsArg

/**
Expand Down Expand Up @@ -340,12 +364,14 @@ abstract class OneToManyTTLState(
elementKeySchemaArg: StructType,
ttlConfigArg: TTLConfig,
batchTimestampMsArg: Long,
prevBatchTimestampMsArg: Option[Long],
metricsArg: Map[String, SQLMetric]) extends TTLState {
override private[sql] def stateName: String = stateNameArg
override private[sql] def store: StateStore = storeArg
override private[sql] def elementKeySchema: StructType = elementKeySchemaArg
override private[sql] def ttlConfig: TTLConfig = ttlConfigArg
override private[sql] def batchTimestampMs: Long = batchTimestampMsArg
override private[sql] def prevBatchTimestampMs: Option[Long] = prevBatchTimestampMsArg
override private[sql] def metrics: Map[String, SQLMetric] = metricsArg

// Schema of the min-expiry index: elementKey -> minExpirationMs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ import org.apache.spark.sql.streaming.{TTLConfig, ValueState}
* @param valEncoder - Spark SQL encoder for value
* @param ttlConfig - TTL configuration for values stored in this state
* @param batchTimestampMs - current batch processing timestamp.
* @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive).
* Entries with expiration at or below this timestamp are assumed
* to have been already cleaned up and will be skipped during
* TTL eviction scans.
* @param metrics - metrics to be updated as part of stateful processing
* @tparam S - data type of object that will be stored
*/
Expand All @@ -43,9 +47,11 @@ class ValueStateImplWithTTL[S](
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
prevBatchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric] = Map.empty)
extends OneToOneTTLState(
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, metrics) with ValueState[S] {
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs,
prevBatchTimestampMs, metrics) with ValueState[S] {

private val stateTypesEncoder =
StateTypesEncoder(keyExprEnc, valEncoder, stateName, hasTtl = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ class IncrementalExecution(
t.copy(
stateInfo = Some(nextStatefulOperationStateInfo()),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
prevBatchTimestampMs = prevOffsetSeqMetadata.map(_.batchTimestampMs),
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
hasInitialState = hasInitialState
Expand All @@ -394,6 +395,7 @@ class IncrementalExecution(
t.copy(
stateInfo = Some(nextStatefulOperationStateInfo()),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
prevBatchTimestampMs = prevOffsetSeqMetadata.map(_.batchTimestampMs),
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
hasInitialState = hasInitialState
Expand Down
Loading
Loading