Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slice Callback Function with Batch Interval Functionality #129

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
163 changes: 160 additions & 3 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
logger: &noopLogger{
logger: log.New(ioutil.Discard, "", log.LstdFlags),
},
scanInterval: 250 * time.Millisecond,
maxRecords: 10000,
scanInterval: 250 * time.Millisecond,
maxRecords: 10000,
batchInterval: 0,
}

// override defaults
Expand Down Expand Up @@ -80,15 +81,24 @@ type Consumer struct {
scanInterval time.Duration
maxRecords int64
isAggregated bool
batchInterval int64
}

// ScanFunc is the type of the function called for each message read
// from the stream. The record argument contains the original record
// returned from the AWS Kinesis library.
// If an error is returned, scanning stops. The sole exception is when the
// function returns the special value ErrSkipCheckpoint.

type ScanFunc func(*Record) error

//ScanFuncBatch is the type of function called for read on a slice of records
//from the steram. The Record argument contains the batch of the last unseen records
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

steram -> stream

// If an error is returned, scanning stops. The sole exception is when the
// function returns the special value ErrSkipCheckpoint.

type ScanFuncBatch func([]*kinesis.Record) error

// ErrSkipCheckpoint is used as a return value from ScanFunc to indicate that
// the current checkpoint should be skipped skipped. It is not returned
// as an error by any function.
Expand Down Expand Up @@ -138,7 +148,49 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
return <-errc
}

// ScanShard loops over records on a specific shard, calls the callback func
//ScanBatch performs scan function using intereval batching for invoking callback function
func (c *Consumer) ScanBatch(ctx context.Context, fn ScanFuncBatch) error {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be ScanBatchFunc

ctx, cancel := context.WithCancel(ctx)
defer cancel()

var (
errc = make(chan error, 1)
shardc = make(chan *kinesis.Shard, 1)
)

go func() {
c.group.Start(ctx, shardc)
<-ctx.Done()
close(shardc)
}()

wg := new(sync.WaitGroup)
// process each of the shards
for shard := range shardc {
wg.Add(1)
go func(shardID string) {
defer wg.Done()
if err := c.ScanShardWithIntervalBatching(ctx, shardID, fn); err != nil {
select {
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
// first error to occur
cancel()
default:
// error has already occurred
}
}
}(aws.StringValue(shard.ShardId))
}

go func() {
wg.Wait()
close(errc)
}()

return <-errc
}

// ScanShardContinuous loops over records on a specific shard, calls the callback func
// for each record and checkpoints the progress of scan.
func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) error {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather not break the public API if possible. lemme put some thought into this one

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good : )

// get last seq number from checkpoint
Expand Down Expand Up @@ -231,6 +283,111 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
}
}

// ScanShardWithIntervalBatching waits for a specific second time interval to occur and then performs a batch fetch with specific shard, calls the callback func
// with an array of records and checkpoints the scans

func (c *Consumer) ScanShardWithIntervalBatching(ctx context.Context, shardID string, fn ScanFuncBatch) error {
// get last seq number from checkpoint

lastSeqNum, err := c.group.GetCheckpoint(c.streamName, shardID)
if err != nil {
return fmt.Errorf("get checkpoint error: %v", err)
}

// get shard iterator
shardIterator, err := c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}

c.logger.Log("[CONSUMER] start scan:", shardID, lastSeqNum)
defer func() {
c.logger.Log("[CONSUMER] stop scan:", shardID)
}()
scanTicker := time.NewTicker(c.scanInterval)
defer scanTicker.Stop()
time_start := time.Now()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you run a linter on this code you'll see a bunch of warnings around CamelCase vs snake_case

for {

if time_elapsed := time.Since(time_start); int64(time_elapsed.Seconds()) <= c.batchInterval {
Copy link
Owner

@harlow harlow Dec 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this doing? is it different from the scan interval?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Effectively just a poorer duplication of scan interval, when I wrote this initially I overlooked that component of the code... thank you for pointing this out

continue
}

time_start = time.Now()
resp, err := c.client.GetRecords(&kinesis.GetRecordsInput{
Limit: aws.Int64(c.maxRecords),
ShardIterator: shardIterator,
})

// attempt to recover from GetRecords error when expired iterator
if err != nil {
c.logger.Log("[CONSUMER] get records error:", err.Error())

if awserr, ok := err.(awserr.Error); ok {
if _, ok := retriableErrors[awserr.Code()]; !ok {
return fmt.Errorf("get records error: %v", awserr.Message())
}
}

shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
} else {
// loop over records, call callback func
var records []*kinesis.Record
var err error
if c.isAggregated {
records, err = deaggregator.DeaggregateRecords(resp.Records)
if err != nil {
return err
}
} else {
records = resp.Records
}

records_length := len(records)

if records_length == 0 {
err = fn(nil)
continue

} else {
err = fn(records)
}
last_record := records[records_length-1]

if err != nil && err != ErrSkipCheckpoint {
return err
}
lastSeqNum = *last_record.SequenceNumber

if err != ErrSkipCheckpoint {
if err := c.group.SetCheckpoint(c.streamName, shardID, *last_record.SequenceNumber); err != nil {
return err
}
}

c.counter.Add("records", int64(records_length))

if isShardClosed(resp.NextShardIterator, shardIterator) {
c.logger.Log("[CONSUMER] shard closed:", shardID)
return nil
}

shardIterator = resp.NextShardIterator
}

// Wait for next scan
select {
case <-ctx.Done():
return nil
case <-scanTicker.C:
continue
}
}
}

var retriableErrors = map[string]struct{}{
kinesis.ErrCodeExpiredIteratorException: struct{}{},
kinesis.ErrCodeProvisionedThroughputExceededException: struct{}{},
Expand Down
7 changes: 7 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ func WithAggregation(a bool) Option {
c.isAggregated = a
}
}

//WithBatchSecondInterval overrides the batch retrieval interval for the consumer
func WithBatchSecondInterval(k int64) Option {
return func(c *Consumer) {
c.batchInterval = k
}
}