Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
outputAttr,
stateInfo = None,
batchTimestampMs = None,
prevBatchTimestampMs = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
planLater(child),
Expand Down Expand Up @@ -815,6 +816,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
func, t.leftAttributes, outputAttrs, outputMode, timeMode,
stateInfo = None,
batchTimestampMs = None,
prevBatchTimestampMs = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
userFacingDataType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ case class TransformWithStateInPySparkExec(
timeMode: TimeMode,
stateInfo: Option[StatefulOperatorStateInfo],
batchTimestampMs: Option[Long],
prevBatchTimestampMs: Option[Long] = None,
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
userFacingDataType: TransformWithStateInPySpark.UserFacingDataType.Value,
Expand Down Expand Up @@ -314,7 +315,8 @@ case class TransformWithStateInPySparkExec(
val data = groupAndProject(filteredIter, groupingAttributes, child.output, dedupAttributes)

val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs,
prevBatchTimestampMs, metrics)

val evalType = {
if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) {
Expand Down Expand Up @@ -442,6 +444,7 @@ object TransformWithStateInPySparkExec {
Some(System.currentTimeMillis),
None,
None,
None,
userFacingDataType,
child,
isStreaming = false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ case class TransformWithStateExec(
outputObjAttr: Attribute,
stateInfo: Option[StatefulOperatorStateInfo],
batchTimestampMs: Option[Long],
prevBatchTimestampMs: Option[Long] = None,
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
child: SparkPlan,
Expand Down Expand Up @@ -251,7 +252,7 @@ case class TransformWithStateExec(
case ProcessingTime =>
assert(batchTimestampMs.isDefined)
val batchTimestamp = batchTimestampMs.get
processorHandle.getExpiredTimers(batchTimestamp)
processorHandle.getExpiredTimers(batchTimestamp, prevBatchTimestampMs)
.flatMap { case (keyObj, expiryTimestampMs) =>
numExpiredTimers += 1
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
Expand All @@ -260,7 +261,13 @@ case class TransformWithStateExec(
case EventTime =>
assert(eventTimeWatermarkForEviction.isDefined)
val watermark = eventTimeWatermarkForEviction.get
processorHandle.getExpiredTimers(watermark)
// Only use the late-events watermark as the scan lower bound when a previous batch
// actually existed (prevBatchTimestampMs is set). In the very first batch the
// watermark propagation yields Some(0) for late events even though no timers have
// been processed yet, which would incorrectly skip timers registered at timestamp 0.
val prevWatermark =
if (prevBatchTimestampMs.isDefined) eventTimeWatermarkForLateEvents else None
processorHandle.getExpiredTimers(watermark, prevWatermark)
.flatMap { case (keyObj, expiryTimestampMs) =>
numExpiredTimers += 1
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
Expand Down Expand Up @@ -493,7 +500,7 @@ case class TransformWithStateExec(
CompletionIterator[InternalRow, Iterator[InternalRow]] = {
val processorHandle = new StatefulProcessorHandleImpl(
store, getStateInfo.queryRunId, keyEncoder, timeMode,
isStreaming, batchTimestampMs, metrics)
isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.setHandle(processorHandle)
withStatefulProcessorErrorHandling("init") {
Expand All @@ -509,7 +516,7 @@ case class TransformWithStateExec(
initStateIterator: Iterator[InternalRow]):
CompletionIterator[InternalRow, Iterator[InternalRow]] = {
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
keyEncoder, timeMode, isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.setHandle(processorHandle)
withStatefulProcessorErrorHandling("init") {
Expand Down Expand Up @@ -581,6 +588,7 @@ object TransformWithStateExec {
Some(System.currentTimeMillis),
None,
None,
None,
child,
isStreaming = false,
hasInitialState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class StatefulProcessorHandleImpl(
timeMode: TimeMode,
isStreaming: Boolean = true,
batchTimestampMs: Option[Long] = None,
prevBatchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric] = Map.empty)
extends StatefulProcessorHandleImplBase(timeMode, keyEncoder) with Logging {
import StatefulProcessorHandleState._
Expand Down Expand Up @@ -171,13 +172,19 @@ class StatefulProcessorHandleImpl(

/**
 * Function to retrieve all expired registered timers for all grouping keys.
 *
 * @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive);
 *                          this function will return all timers that have timestamp
 *                          less than or equal to the passed threshold.
 * @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range.
 *                              Timers at or below this timestamp are assumed to have been
 *                              already processed in the previous batch and will be skipped.
 * @return iterator of (grouping key, expiry timestamp) pairs for all expired timers
 */
def getExpiredTimers(
    expiryTimestampMs: Long,
    prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = {
  // Fail fast if timer operations are not valid in the current handle state / time mode.
  verifyTimerOperations("get_expired_timers")
  timerState.getExpiredTimers(expiryTimestampMs, prevExpiryTimestampMs)
}

/**
Expand Down Expand Up @@ -237,7 +244,8 @@ class StatefulProcessorHandleImpl(
validateTTLConfig(ttlConfig, stateName)
assert(batchTimestampMs.isDefined)
val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName,
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get,
prevBatchTimestampMs, metrics)
ttlStates.add(valueStateWithTTL)
TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars")
valueStateWithTTL
Expand Down Expand Up @@ -286,7 +294,8 @@ class StatefulProcessorHandleImpl(
validateTTLConfig(ttlConfig, stateName)
assert(batchTimestampMs.isDefined)
val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get,
prevBatchTimestampMs, metrics)
TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars")
ttlStates.add(listStateWithTTL)
listStateWithTTL
Expand Down Expand Up @@ -324,7 +333,8 @@ class StatefulProcessorHandleImpl(
validateTTLConfig(ttlConfig, stateName)
assert(batchTimestampMs.isDefined)
val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc,
valEncoder, ttlConfig, batchTimestampMs.get, metrics)
valEncoder, ttlConfig, batchTimestampMs.get,
prevBatchTimestampMs, metrics)
TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars")
ttlStates.add(mapStateWithTTL)
mapStateWithTTL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,27 @@ class TimerStateImpl(
schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, Seq(0)),
useMultipleValuesPerKey = false, isInternal = true)

private val secIndexProjection = UnsafeProjection.create(keySchemaForSecIndex)

/**
 * Encodes a timestamp into an UnsafeRow key usable as a range-scan bound on the
 * secondary (timestamp-to-key) index.
 *
 * The stored value is tsMs + 1, so the produced key acts as an exclusive
 * lower / upper bound in range scans. Returns None when tsMs is Long.MaxValue,
 * since incrementing it would overflow.
 *
 * The returned UnsafeRow is always a fresh copy, safe to hold alongside other
 * rows produced by the same projection.
 */
private def encodeTimestampAsKey(tsMs: Long): Option[UnsafeRow] = {
  if (tsMs == Long.MaxValue) {
    // No representable exclusive bound beyond Long.MaxValue — signal "unbounded".
    None
  } else {
    val boundRow = new GenericInternalRow(keySchemaForSecIndex.length)
    boundRow.setLong(0, tsMs + 1)
    Some(secIndexProjection.apply(boundRow).copy())
  }
}

private def getGroupingKey(cfName: String): Any = {
val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
if (keyOption.isEmpty) {
Expand Down Expand Up @@ -189,15 +210,22 @@ class TimerStateImpl(

/**
* Function to get all the expired registered timers for all grouping keys.
* Perform a range scan on timestamp and will stop iterating once the key row timestamp equals or
* Perform a range scan on timestamp and will stop iterating once the key row timestamp
* exceeds the limit (as timestamp key is increasingly sorted).
* @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function
* will return all timers that have timestamp less than passed threshold.
* @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive),
* this function will return all timers that have timestamp
* less than or equal to the passed threshold.
* @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range.
* Timers at or below this timestamp are assumed to have been
* already processed in the previous batch and will be skipped.
* @return - iterator of all the registered timers for all grouping keys
*/
def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
// this iter is increasingly sorted on timestamp
val iter = store.iterator(tsToKeyCFName)
def getExpiredTimers(
expiryTimestampMs: Long,
prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = {
val startKey = prevExpiryTimestampMs.flatMap(encodeTimestampAsKey)
val endKey = encodeTimestampAsKey(expiryTimestampMs)
val iter = store.rangeScan(startKey, endKey, tsToKeyCFName)

new NextIterator[(Any, Long)] {
override protected def getNext(): (Any, Long) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ import org.apache.spark.util.NextIterator
* @param valEncoder - Spark SQL encoder for value
* @param ttlConfig - TTL configuration for values stored in this state
* @param batchTimestampMs - current batch processing timestamp.
* @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive).
* Entries with expiration at or below this timestamp are assumed
* to have been already cleaned up and will be skipped during
* TTL eviction scans.
* @param metrics - metrics to be updated as part of stateful processing
* @tparam S - data type of object that will be stored
*/
Expand All @@ -45,9 +49,11 @@ class ListStateImplWithTTL[S](
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
prevBatchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric])
extends OneToManyTTLState(
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, metrics) with ListState[S] {
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs,
prevBatchTimestampMs, metrics) with ListState[S] {

private lazy val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder,
stateName, hasTtl = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ import org.apache.spark.util.NextIterator
* @param valEncoder - SQL encoder for state variable
* @param ttlConfig - the ttl configuration (time to live duration etc.)
* @param batchTimestampMs - current batch processing timestamp.
* @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive).
* Entries with expiration at or below this timestamp are assumed
* to have been already cleaned up and will be skipped during
* TTL eviction scans.
* @param metrics - metrics to be updated as part of stateful processing
* @tparam K - type of key for map state variable
* @tparam V - type of value for map state variable
Expand All @@ -49,10 +53,11 @@ class MapStateImplWithTTL[K, V](
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
metrics: Map[String, SQLMetric])
prevBatchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric])
extends OneToOneTTLState(
stateName, store, getCompositeKeySchema(keyExprEnc.schema, userKeyEnc.schema), ttlConfig,
batchTimestampMs, metrics) with MapState[K, V] with Logging {
batchTimestampMs, prevBatchTimestampMs, metrics) with MapState[K, V] with Logging {

private val stateTypesEncoder = new CompositeKeyStateEncoder(
keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ trait TTLState {
// an expiration at or before this timestamp must be cleaned up.
private[sql] def batchTimestampMs: Long

// The batch timestamp from the previous micro-batch, used to derive the startKey
// for scan-based TTL eviction. Entries at or below prevBatchTimestampMs were already
// cleaned up in the previous batch.
private[sql] def prevBatchTimestampMs: Option[Long]

// The configuration for this run of the streaming query. It may change between runs
// (e.g. user sets ttlConfig1, stops their query, updates to ttlConfig2, and then
// resumes their query).
Expand All @@ -105,6 +110,8 @@ trait TTLState {

private final val TTL_ENCODER = new TTLEncoder(elementKeySchema)

private final val ELEMENT_KEY_PROJECTION = UnsafeProjection.create(elementKeySchema)

// Empty row used for values
private final val TTL_EMPTY_VALUE_ROW =
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
Expand Down Expand Up @@ -161,10 +168,25 @@ trait TTLState {
//
// The schema of the UnsafeRow returned by this iterator is (expirationMs, elementKey).
private[sql] def ttlEvictionIterator(): Iterator[UnsafeRow] = {
val ttlIterator = store.iterator(TTL_INDEX)
val dummyElementKey = ELEMENT_KEY_PROJECTION
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the dummyElementKey here present just because TTL_ENCODER.encodeTTLRow requires an elementKey argument?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, correct.

.apply(new GenericInternalRow(elementKeySchema.length))
val startKey = prevBatchTimestampMs.flatMap { prevTs =>
if (prevTs < Long.MaxValue) {
Some(TTL_ENCODER.encodeTTLRow(prevTs + 1, dummyElementKey).copy())
} else {
None
}
}
val endKey = if (batchTimestampMs < Long.MaxValue) {
Some(TTL_ENCODER.encodeTTLRow(batchTimestampMs + 1, dummyElementKey).copy())
} else {
None
}
val ttlIterator = store.rangeScan(startKey, endKey, TTL_INDEX)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we update the method comment indicating that we are using range scan now?


// Recall that the format is (expirationMs, elementKey) -> TTL_EMPTY_VALUE_ROW, so
// kv.value doesn't ever need to be used.
// Safety filter: keep only truly expired entries
ttlIterator.takeWhile { kv =>
val expirationMs = kv.key.getLong(0)
StateTTL.isExpired(expirationMs, batchTimestampMs)
Expand Down Expand Up @@ -223,12 +245,14 @@ abstract class OneToOneTTLState(
elementKeySchemaArg: StructType,
ttlConfigArg: TTLConfig,
batchTimestampMsArg: Long,
prevBatchTimestampMsArg: Option[Long],
metricsArg: Map[String, SQLMetric]) extends TTLState {
override private[sql] def stateName: String = stateNameArg
override private[sql] def store: StateStore = storeArg
override private[sql] def elementKeySchema: StructType = elementKeySchemaArg
override private[sql] def ttlConfig: TTLConfig = ttlConfigArg
override private[sql] def batchTimestampMs: Long = batchTimestampMsArg
override private[sql] def prevBatchTimestampMs: Option[Long] = prevBatchTimestampMsArg
override private[sql] def metrics: Map[String, SQLMetric] = metricsArg

/**
Expand Down Expand Up @@ -340,12 +364,14 @@ abstract class OneToManyTTLState(
elementKeySchemaArg: StructType,
ttlConfigArg: TTLConfig,
batchTimestampMsArg: Long,
prevBatchTimestampMsArg: Option[Long],
metricsArg: Map[String, SQLMetric]) extends TTLState {
override private[sql] def stateName: String = stateNameArg
override private[sql] def store: StateStore = storeArg
override private[sql] def elementKeySchema: StructType = elementKeySchemaArg
override private[sql] def ttlConfig: TTLConfig = ttlConfigArg
override private[sql] def batchTimestampMs: Long = batchTimestampMsArg
override private[sql] def prevBatchTimestampMs: Option[Long] = prevBatchTimestampMsArg
override private[sql] def metrics: Map[String, SQLMetric] = metricsArg

// Schema of the min-expiry index: elementKey -> minExpirationMs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ import org.apache.spark.sql.streaming.{TTLConfig, ValueState}
* @param valEncoder - Spark SQL encoder for value
* @param ttlConfig - TTL configuration for values stored in this state
* @param batchTimestampMs - current batch processing timestamp.
* @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive).
* Entries with expiration at or below this timestamp are assumed
* to have been already cleaned up and will be skipped during
* TTL eviction scans.
* @param metrics - metrics to be updated as part of stateful processing
* @tparam S - data type of object that will be stored
*/
Expand All @@ -43,9 +47,11 @@ class ValueStateImplWithTTL[S](
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
prevBatchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric] = Map.empty)
extends OneToOneTTLState(
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, metrics) with ValueState[S] {
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs,
prevBatchTimestampMs, metrics) with ValueState[S] {

private val stateTypesEncoder =
StateTypesEncoder(keyExprEnc, valEncoder, stateName, hasTtl = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ class IncrementalExecution(
t.copy(
stateInfo = Some(nextStatefulOperationStateInfo()),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
prevBatchTimestampMs = prevOffsetSeqMetadata.map(_.batchTimestampMs),
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
hasInitialState = hasInitialState
Expand All @@ -394,6 +395,7 @@ class IncrementalExecution(
t.copy(
stateInfo = Some(nextStatefulOperationStateInfo()),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
prevBatchTimestampMs = prevOffsetSeqMetadata.map(_.batchTimestampMs),
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
hasInitialState = hasInitialState
Expand Down
Loading