diff --git a/cmd/query/app/handler_archive_test.go b/cmd/query/app/handler_archive_test.go index 87ae3e304b2..45db5ea90ce 100644 --- a/cmd/query/app/handler_archive_test.go +++ b/cmd/query/app/handler_archive_test.go @@ -7,6 +7,7 @@ package app import ( "errors" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -114,6 +115,24 @@ func TestArchiveTrace_Success(t *testing.T) { }, querysvc.QueryServiceOptions{ArchiveSpanWriter: mockWriter}) } +func TestArchiveTrace_SucessWithTimeWindow(t *testing.T) { + mockWriter := &spanstoremocks.Writer{} + mockWriter.On("WriteSpan", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("*model.Span")). + Return(nil).Times(2) + withTestServer(t, func(ts *testServer) { + expectedQuery := spanstore.GetTraceParameters{ + TraceID: mockTraceID, + StartTime: time.UnixMicro(1), + EndTime: time.UnixMicro(2), + } + ts.spanReader.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), expectedQuery). + Return(mockTrace, nil).Once() + var response structuredTraceResponse + err := postJSON(ts.server.URL+"/api/archive/"+mockTraceID.String()+"?start=1&end=2", []string{}, &response) + require.NoError(t, err) + }, querysvc.QueryServiceOptions{ArchiveSpanWriter: mockWriter}) +} + func TestArchiveTrace_WriteErrors(t *testing.T) { mockWriter := &spanstoremocks.Writer{} mockWriter.On("WriteSpan", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("*model.Span")). @@ -126,3 +145,35 @@ func TestArchiveTrace_WriteErrors(t *testing.T) { require.EqualError(t, err, `500 error from server: {"data":null,"total":0,"limit":0,"offset":0,"errors":[{"code":500,"msg":"cannot save\ncannot save"}]}`+"\n") }, querysvc.QueryServiceOptions{ArchiveSpanWriter: mockWriter}) } + +func TestArchiveTrace_BadTimeWindow(t *testing.T) { + testCases := []struct { + name string + query string + }{ + { + name: "Bad start time", + query: "start=a", + }, + { + name: "Bad end time", + query: "end=b", + }, + } + mockWriter := &spanstoremocks.Writer{} + mockWriter.On("WriteSpan", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("*model.Span")). + Return(nil).Times(2 * len(testCases)) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + withTestServer(t, func(ts *testServer) { + ts.spanReader.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("spanstore.GetTraceParameters")). + Return(mockTrace, nil).Once() + var response structuredTraceResponse + err := postJSON(ts.server.URL+"/api/archive/"+mockTraceID.String()+"?"+tc.query, []string{}, &response) + require.Error(t, err) + require.ErrorContains(t, err, "400 error from server") + require.ErrorContains(t, err, "unable to parse param") + }, querysvc.QueryServiceOptions{ArchiveSpanWriter: mockWriter}) + }) + } +} diff --git a/cmd/query/app/http_handler.go b/cmd/query/app/http_handler.go index 157fc398626..02ed05c059f 100644 --- a/cmd/query/app/http_handler.go +++ b/cmd/query/app/http_handler.go @@ -231,7 +231,12 @@ func (aH *APIHandler) search(w http.ResponseWriter, r *http.Request) { var uiErrors []structuredError var tracesFromStorage []*model.Trace if len(tQuery.traceIDs) > 0 { - tracesFromStorage, uiErrors, err = aH.tracesByIDs(r.Context(), tQuery.traceIDs) + tracesFromStorage, uiErrors, err = aH.tracesByIDs( + r.Context(), + tQuery.traceIDs, + tQuery.StartTimeMin, + tQuery.StartTimeMax, + ) if aH.handleError(w, err, http.StatusInternalServerError) { return } @@ -262,13 +267,14 @@ func (aH *APIHandler) tracesToResponse(traces []*model.Trace, adjust bool, uiErr } } -func (aH *APIHandler) tracesByIDs(ctx context.Context, traceIDs []model.TraceID) ([]*model.Trace, []structuredError, error) { +func (aH *APIHandler) tracesByIDs(ctx context.Context, traceIDs []model.TraceID, startTime time.Time, endTime time.Time) ([]*model.Trace, []structuredError, error) { var traceErrors []structuredError retMe := make([]*model.Trace, 0, len(traceIDs)) for _, traceID := range traceIDs { - // TODO: add start time & end time query := spanstore.GetTraceParameters{ - TraceID: traceID, + TraceID: traceID, + StartTime: startTime, + EndTime: endTime, } if trc, err := aH.queryService.GetTrace(ctx, query); err != nil { if !errors.Is(err, spanstore.ErrTraceNotFound) { @@ -425,18 +431,46 @@ func (aH *APIHandler) parseTraceID(w http.ResponseWriter, r *http.Request) (mode return traceID, true } +func (aH *APIHandler) parseMicroseconds(w http.ResponseWriter, r *http.Request, timeKey string) (time.Time, bool) { + if timeString := r.FormValue(timeKey); timeString != "" { + t, err := aH.queryParser.parseTime(r, timeKey, time.Microsecond) + if aH.handleError(w, err, http.StatusBadRequest) { + return time.Time{}, false + } + return t, true + } + // It's OK if no time is specified, since it's optional + return time.Time{}, true +} + +func (aH *APIHandler) parseGetTraceParameters(w http.ResponseWriter, r *http.Request) (spanstore.GetTraceParameters, bool) { + query := spanstore.GetTraceParameters{} + traceID, ok := aH.parseTraceID(w, r) + if !ok { + return query, false + } + startTime, ok := aH.parseMicroseconds(w, r, startTimeParam) + if !ok { + return query, false + } + endTime, ok := aH.parseMicroseconds(w, r, endTimeParam) + if !ok { + return query, false + } + query.TraceID = traceID + query.StartTime = startTime + query.EndTime = endTime + return query, true +} + // getTrace implements the REST API /traces/{trace-id} // It parses trace ID from the path, fetches the trace from QueryService, // formats it in the UI JSON format, and responds to the client. func (aH *APIHandler) getTrace(w http.ResponseWriter, r *http.Request) { - traceID, ok := aH.parseTraceID(w, r) + query, ok := aH.parseGetTraceParameters(w, r) if !ok { return } - // TODO: add start time & end time - query := spanstore.GetTraceParameters{ - TraceID: traceID, - } trc, err := aH.queryService.GetTrace(r.Context(), query) if errors.Is(err, spanstore.ErrTraceNotFound) { aH.handleError(w, err, http.StatusNotFound) @@ -460,16 +494,12 @@ func shouldAdjust(r *http.Request) bool { // archiveTrace implements the REST API POST:/archive/{trace-id}. // It passes the traceID to queryService.ArchiveTrace for writing. func (aH *APIHandler) archiveTrace(w http.ResponseWriter, r *http.Request) { - traceID, ok := aH.parseTraceID(w, r) + query, ok := aH.parseGetTraceParameters(w, r) if !ok { return } // QueryService.ArchiveTrace can now archive this traceID. - // TODO: add start time & end time - query := spanstore.GetTraceParameters{ - TraceID: traceID, - } err := aH.queryService.ArchiveTrace(r.Context(), query) if errors.Is(err, spanstore.ErrTraceNotFound) { aH.handleError(w, err, http.StatusNotFound) diff --git a/cmd/query/app/http_handler_test.go b/cmd/query/app/http_handler_test.go index 126e4c194bb..479466d0da0 100644 --- a/cmd/query/app/http_handler_test.go +++ b/cmd/query/app/http_handler_test.go @@ -198,6 +198,20 @@ func TestGetTraceDedupeSuccess(t *testing.T) { } } +func TestGetTraceWithTimeWindowSuccess(t *testing.T) { + ts := initializeTestServer(t) + ts.spanReader.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), spanstore.GetTraceParameters{ + TraceID: mockTraceID, + StartTime: time.UnixMicro(1), + EndTime: time.UnixMicro(2), + }).Return(mockTrace, nil).Once() + + var response structuredResponse + err := getJSON(ts.server.URL+`/api/traces/`+mockTraceID.String()+`?start=1&end=2`, &response) + require.NoError(t, err) + assert.Empty(t, response.Errors) +} + func TestLogOnServerError(t *testing.T) { zapCore, logs := observer.New(zap.InfoLevel) logger := zap.New(zapCore) @@ -388,6 +402,32 @@ func TestGetTraceBadTraceID(t *testing.T) { require.Error(t, err) } +func TestGetTraceBadTimeWindow(t *testing.T) { + testCases := []struct { + name string + query string + }{ + { + name: "Bad start time", + query: "start=a", + }, + { + name: "Bad end time", + query: "end=b", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ts := initializeTestServer(t) + var response structuredResponse + err := getJSON(ts.server.URL+`/api/traces/123456?`+tc.query, &response) + require.Error(t, err) + require.ErrorContains(t, err, "400 error from server") + require.ErrorContains(t, err, "unable to parse param") + }) + } +} + func TestSearchSuccess(t *testing.T) { ts := initializeTestServer(t) ts.spanReader.On("FindTraces", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("*spanstore.TraceQueryParameters")). @@ -411,6 +451,57 @@ func TestSearchByTraceIDSuccess(t *testing.T) { assert.Len(t, response.Data, 2) } +func TestSearchByTraceIDWithTimeWindowSuccess(t *testing.T) { + ts := initializeTestServer(t) + expectedQuery1 := spanstore.GetTraceParameters{ + TraceID: mockTraceID, + StartTime: time.UnixMicro(1), + EndTime: time.UnixMicro(2), + } + traceId2 := model.NewTraceID(0, 456789) + expectedQuery2 := spanstore.GetTraceParameters{ + TraceID: traceId2, + StartTime: time.UnixMicro(1), + EndTime: time.UnixMicro(2), + } + ts.spanReader.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), expectedQuery1). + Return(mockTrace, nil) + ts.spanReader.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), expectedQuery2). + Return(mockTrace, nil) + + var response structuredResponse + err := getJSON(ts.server.URL+`/api/traces?traceID=`+mockTraceID.String()+`&traceID=`+traceId2.String()+`&start=1&end=2`, &response) + require.NoError(t, err) + assert.Empty(t, response.Errors) + assert.Len(t, response.Data, 2) +} + +func TestSearchTraceBadTimeWindow(t *testing.T) { + testCases := []struct { + name string + query string + }{ + { + name: "Bad start time", + query: "start=a", + }, + { + name: "Bad end time", + query: "end=b", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ts := initializeTestServer(t) + var response structuredResponse + err := getJSON(ts.server.URL+`/api/traces?traceID=1&traceID=2&`+tc.query, &response) + require.Error(t, err) + require.ErrorContains(t, err, "400 error from server") + require.ErrorContains(t, err, "unable to parse param") + }) + } +} + func TestSearchByTraceIDSuccessWithArchive(t *testing.T) { archiveReadMock := &spanstoremocks.Reader{} ts := initializeTestServerWithOptions(t, &tenancy.Manager{}, querysvc.QueryServiceOptions{ @@ -428,6 +519,28 @@ func TestSearchByTraceIDSuccessWithArchive(t *testing.T) { assert.Len(t, response.Data, 2) } +func TestSearchByTraceIDSuccessWithArchiveAndTimeWindow(t *testing.T) { + archiveReadMock := &spanstoremocks.Reader{} + ts := initializeTestServerWithOptions(t, &tenancy.Manager{}, querysvc.QueryServiceOptions{ + ArchiveSpanReader: archiveReadMock, + }) + expectedQuery := spanstore.GetTraceParameters{ + TraceID: mockTraceID, + StartTime: time.UnixMicro(1), + EndTime: time.UnixMicro(2), + } + ts.spanReader.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), expectedQuery). + Return(nil, spanstore.ErrTraceNotFound) + archiveReadMock.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), expectedQuery). + Return(mockTrace, nil) + + var response structuredResponse + err := getJSON(ts.server.URL+`/api/traces?traceID=`+mockTraceID.String()+`&start=1&end=2`, &response) + require.NoError(t, err) + assert.Empty(t, response.Errors) + assert.Len(t, response.Data, 1) +} + func TestSearchByTraceIDNotFound(t *testing.T) { ts := initializeTestServer(t) ts.spanReader.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("spanstore.GetTraceParameters")).