-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Support server to client requests (#11)
* This is a breaking change, but supports Server -> Client request and notifications.
- Loading branch information
1 parent
f79c24e
commit d661166
Showing
15 changed files
with
521 additions
and
292 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package mcp | ||
|
||
import ( | ||
"context" | ||
"strconv" | ||
) | ||
|
||
type base struct { | ||
router *router | ||
stream Stream | ||
interceptors []Interceptor | ||
} | ||
|
||
func (b *base) listen(ctx context.Context, handler func(ctx context.Context, msg *Message) error) error { | ||
for { | ||
msg, err := b.stream.Recv() | ||
if err != nil { | ||
return err | ||
} | ||
if msg.Method != nil { | ||
go func() { | ||
handler(ctx, msg) | ||
}() | ||
} else { | ||
id, err := strconv.ParseUint(msg.ID.String(), 10, 64) | ||
if err != nil { | ||
continue | ||
} | ||
if inbox, ok := b.router.Remove(id); ok { | ||
inbox <- msg | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package mcp | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"strconv" | ||
) | ||
|
||
func call[P any, R any](ctx context.Context, c *base, method string, req *Request[P]) (*Response[R], error) { | ||
id, inbox := c.router.Add() | ||
|
||
var interceptor Interceptor | ||
if len(c.interceptors) > 0 { | ||
interceptor = newStack(c.interceptors) | ||
} else { | ||
interceptor = UnaryInterceptorFunc( | ||
func(next UnaryFunc) UnaryFunc { | ||
return UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { | ||
return next(ctx, request) | ||
}) | ||
}, | ||
) | ||
} | ||
|
||
inner := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { | ||
rawmsg, err := json.Marshal(req.Params) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
msgID := json.Number(request.ID()) | ||
msgVersion := "2.0" | ||
msgParams := json.RawMessage(rawmsg) | ||
|
||
msg := &Message{ | ||
ID: &msgID, | ||
JsonRPC: &msgVersion, | ||
Method: &method, | ||
Params: &msgParams, | ||
} | ||
|
||
if err := c.stream.Send(msg); err != nil { | ||
return nil, err | ||
} | ||
|
||
var result R | ||
|
||
select { | ||
case resp := <-inbox: | ||
if resp.Error != nil { | ||
return nil, NewError(resp.Error.Code, errors.New(resp.Error.Message)) | ||
} | ||
if resp.Result == nil { | ||
return nil, fmt.Errorf("no result") | ||
} | ||
if err := json.Unmarshal(*resp.Result, &result); err != nil { | ||
return nil, err | ||
} | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
} | ||
|
||
return NewResponse(&result), nil | ||
}) | ||
|
||
req.id = strconv.FormatUint(id, 10) | ||
req.method = method | ||
|
||
resp, err := interceptor.WrapUnary(inner)(ctx, req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return resp.(*Response[R]), nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,178 +1,107 @@ | ||
package mcp | ||
|
||
import ( | ||
"bufio" | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"strconv" | ||
"sync" | ||
|
||
"github.com/riza-io/mcp-go/internal/jsonrpc" | ||
) | ||
|
||
type Client interface { | ||
Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error) | ||
ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error) | ||
ListTools(ctx context.Context, req *Request[ListToolsRequest]) (*Response[ListToolsResponse], error) | ||
CallTool(ctx context.Context, req *Request[CallToolRequest]) (*Response[CallToolResponse], error) | ||
ListPrompts(ctx context.Context, req *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error) | ||
GetPrompt(ctx context.Context, req *Request[GetPromptRequest]) (*Response[GetPromptResponse], error) | ||
ReadResource(ctx context.Context, req *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error) | ||
ListResourceTemplates(ctx context.Context, req *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error) | ||
Completion(ctx context.Context, req *Request[CompletionRequest]) (*Response[CompletionResponse], error) | ||
Ping(ctx context.Context, req *Request[PingRequest]) (*Response[PingResponse], error) | ||
SetLogLevel(ctx context.Context, req *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error) | ||
} | ||
|
||
type StdioClient struct { | ||
in io.Reader | ||
out io.Writer | ||
scanner *bufio.Scanner | ||
next int | ||
lock sync.Mutex | ||
type ClientHandler interface { | ||
Sampling(ctx context.Context, request *Request[SamplingRequest]) (*Response[SamplingResponse], error) | ||
Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error) | ||
LogMessage(ctx context.Context, request *Request[LogMessageRequest]) | ||
} | ||
|
||
type UnimplementedClient struct{} | ||
|
||
func (u *UnimplementedClient) Sampling(ctx context.Context, request *Request[SamplingRequest]) (*Response[SamplingResponse], error) { | ||
return nil, fmt.Errorf("not implemented") | ||
} | ||
|
||
func (u *UnimplementedClient) LogMessage(ctx context.Context, request *Request[LogMessageRequest]) { | ||
} | ||
|
||
func (c *UnimplementedClient) Ping(ctx context.Context, req *Request[PingRequest]) (*Response[PingResponse], error) { | ||
return NewResponse(&PingResponse{}), nil | ||
} | ||
|
||
type Client struct { | ||
handler ClientHandler | ||
interceptors []Interceptor | ||
base *base | ||
} | ||
|
||
func NewStdioClient(stdin io.Reader, stdout io.Writer, opts ...Option) Client { | ||
c := &StdioClient{ | ||
in: stdin, | ||
out: stdout, | ||
scanner: bufio.NewScanner(stdin), | ||
func NewClient(stream Stream, handler ClientHandler, opts ...Option) *Client { | ||
c := &Client{ | ||
handler: handler, | ||
} | ||
|
||
for _, opt := range opts { | ||
opt.applyToClient(c) | ||
} | ||
|
||
c.base = &base{ | ||
router: newRouter(), | ||
interceptors: c.interceptors, | ||
stream: stream, | ||
} | ||
return c | ||
} | ||
|
||
func clientCallUnary[P any, R any](ctx context.Context, c *StdioClient, method string, req *Request[P]) (*Response[R], error) { | ||
// Ensure that we are not sending multiple requests at the same time | ||
c.lock.Lock() | ||
defer c.lock.Unlock() | ||
|
||
defer func() { | ||
// Increment the ID counter | ||
c.next++ | ||
}() | ||
|
||
var interceptor Interceptor | ||
if len(c.interceptors) > 0 { | ||
interceptor = newStack(c.interceptors) | ||
} else { | ||
interceptor = UnaryInterceptorFunc( | ||
func(next UnaryFunc) UnaryFunc { | ||
return UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { | ||
return next(ctx, request) | ||
}) | ||
}, | ||
) | ||
} | ||
|
||
inner := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { | ||
rawmsg, err := json.Marshal(req.Params) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
msg := jsonrpc.Request{ | ||
ID: json.Number(request.ID()), | ||
JsonRPC: "2.0", | ||
Method: request.Method(), | ||
Params: json.RawMessage(rawmsg), | ||
} | ||
|
||
bs, err := json.Marshal(msg) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
fmt.Fprintln(c.out, string(bs)) | ||
|
||
var result R | ||
|
||
for c.scanner.Scan() { | ||
line := c.scanner.Bytes() | ||
|
||
var resp jsonrpc.Response | ||
|
||
if err := json.Unmarshal(line, &resp); err != nil { | ||
return nil, err | ||
} | ||
|
||
if resp.Error != nil { | ||
return nil, NewError(resp.Error.Code, errors.New(resp.Error.Message)) | ||
} | ||
|
||
if err := json.Unmarshal(resp.Result, &result); err != nil { | ||
return nil, err | ||
} | ||
|
||
break | ||
} | ||
|
||
if err := c.scanner.Err(); err != nil { | ||
return nil, err | ||
} | ||
|
||
return NewResponse(&result), nil | ||
}) | ||
|
||
req.id = strconv.Itoa(c.next) | ||
req.method = method | ||
// sync.Once? | ||
func (c *Client) Listen(ctx context.Context) error { | ||
return c.base.listen(ctx, c.processMessage) | ||
} | ||
|
||
resp, err := interceptor.WrapUnary(inner)(ctx, req) | ||
if err != nil { | ||
return nil, err | ||
func (c *Client) processMessage(ctx context.Context, msg *Message) error { | ||
srv := c.handler | ||
switch m := *msg.Method; m { | ||
case "ping": | ||
return process(ctx, c.base, msg, srv.Ping) | ||
case "notifications/message": | ||
return process(ctx, c.base, msg, noop(srv.LogMessage)) | ||
default: | ||
return fmt.Errorf("unknown method: %s", m) | ||
} | ||
|
||
return resp.(*Response[R]), nil | ||
} | ||
|
||
func (c *StdioClient) Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error) { | ||
return clientCallUnary[InitializeRequest, InitializeResponse](ctx, c, "initialize", request) | ||
func (c *Client) Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error) { | ||
return call[InitializeRequest, InitializeResponse](ctx, c.base, "initialize", request) | ||
} | ||
|
||
func (c *StdioClient) ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error) { | ||
return clientCallUnary[ListResourcesRequest, ListResourcesResponse](ctx, c, "resources/list", request) | ||
func (c *Client) ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error) { | ||
return call[ListResourcesRequest, ListResourcesResponse](ctx, c.base, "resources/list", request) | ||
} | ||
|
||
func (c *StdioClient) ListTools(ctx context.Context, request *Request[ListToolsRequest]) (*Response[ListToolsResponse], error) { | ||
return clientCallUnary[ListToolsRequest, ListToolsResponse](ctx, c, "tools/list", request) | ||
func (c *Client) ListTools(ctx context.Context, request *Request[ListToolsRequest]) (*Response[ListToolsResponse], error) { | ||
return call[ListToolsRequest, ListToolsResponse](ctx, c.base, "tools/list", request) | ||
} | ||
|
||
func (c *StdioClient) CallTool(ctx context.Context, request *Request[CallToolRequest]) (*Response[CallToolResponse], error) { | ||
return clientCallUnary[CallToolRequest, CallToolResponse](ctx, c, "tools/call", request) | ||
func (c *Client) CallTool(ctx context.Context, request *Request[CallToolRequest]) (*Response[CallToolResponse], error) { | ||
return call[CallToolRequest, CallToolResponse](ctx, c.base, "tools/call", request) | ||
} | ||
|
||
func (c *StdioClient) ListPrompts(ctx context.Context, request *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error) { | ||
return clientCallUnary[ListPromptsRequest, ListPromptsResponse](ctx, c, "prompts/list", request) | ||
func (c *Client) ListPrompts(ctx context.Context, request *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error) { | ||
return call[ListPromptsRequest, ListPromptsResponse](ctx, c.base, "prompts/list", request) | ||
} | ||
|
||
func (c *StdioClient) GetPrompt(ctx context.Context, request *Request[GetPromptRequest]) (*Response[GetPromptResponse], error) { | ||
return clientCallUnary[GetPromptRequest, GetPromptResponse](ctx, c, "prompts/get", request) | ||
func (c *Client) GetPrompt(ctx context.Context, request *Request[GetPromptRequest]) (*Response[GetPromptResponse], error) { | ||
return call[GetPromptRequest, GetPromptResponse](ctx, c.base, "prompts/get", request) | ||
} | ||
|
||
func (c *StdioClient) ReadResource(ctx context.Context, request *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error) { | ||
return clientCallUnary[ReadResourceRequest, ReadResourceResponse](ctx, c, "resources/read", request) | ||
func (c *Client) ReadResource(ctx context.Context, request *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error) { | ||
return call[ReadResourceRequest, ReadResourceResponse](ctx, c.base, "resources/read", request) | ||
} | ||
|
||
func (c *StdioClient) ListResourceTemplates(ctx context.Context, request *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error) { | ||
return clientCallUnary[ListResourceTemplatesRequest, ListResourceTemplatesResponse](ctx, c, "resources/templates/list", request) | ||
func (c *Client) ListResourceTemplates(ctx context.Context, request *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error) { | ||
return call[ListResourceTemplatesRequest, ListResourceTemplatesResponse](ctx, c.base, "resources/templates/list", request) | ||
} | ||
|
||
func (c *StdioClient) Completion(ctx context.Context, request *Request[CompletionRequest]) (*Response[CompletionResponse], error) { | ||
return clientCallUnary[CompletionRequest, CompletionResponse](ctx, c, "completion", request) | ||
func (c *Client) Completion(ctx context.Context, request *Request[CompletionRequest]) (*Response[CompletionResponse], error) { | ||
return call[CompletionRequest, CompletionResponse](ctx, c.base, "completion", request) | ||
} | ||
|
||
func (c *StdioClient) Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error) { | ||
return clientCallUnary[PingRequest, PingResponse](ctx, c, "ping", request) | ||
func (c *Client) Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error) { | ||
return call[PingRequest, PingResponse](ctx, c.base, "ping", request) | ||
} | ||
|
||
func (c *StdioClient) SetLogLevel(ctx context.Context, request *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error) { | ||
return clientCallUnary[SetLogLevelRequest, SetLogLevelResponse](ctx, c, "logging/setLevel", request) | ||
func (c *Client) SetLogLevel(ctx context.Context, request *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error) { | ||
return call[SetLogLevelRequest, SetLogLevelResponse](ctx, c.base, "logging/setLevel", request) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.