diff --git a/go.sum b/go.sum index 1943dc8d..65f5adb4 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,8 @@ github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -160,6 +162,8 @@ golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sU golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0= golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= +golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -188,6 +192,8 @@ golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= +golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= diff --git a/internal/api/chat/create_conversation_message_stream_v2.go b/internal/api/chat/create_conversation_message_stream_v2.go index b82adf5d..8c715a36 100644 --- a/internal/api/chat/create_conversation_message_stream_v2.go +++ b/internal/api/chat/create_conversation_message_stream_v2.go @@ -281,7 +281,7 @@ func (s *ChatServerV2) CreateConversationMessageStream( APIKey: settings.OpenAIAPIKey, } - openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider) + openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.UserID, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider) if err != nil { return s.sendStreamError(stream, err) } @@ -307,7 +307,7 @@ func (s *ChatServerV2) CreateConversationMessageStream( for i, bsonMsg := range conversation.InappChatHistory { protoMessages[i] = mapper.BSONToChatMessageV2(bsonMsg) } - title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider) + title, err := s.aiClientV2.GetConversationTitleV2(ctx, conversation.UserID, protoMessages, llmProvider) if err != nil { s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex()) return diff --git a/internal/api/grpc.go b/internal/api/grpc.go index ed9dc2b0..3451d667 100644 --- a/internal/api/grpc.go +++ b/internal/api/grpc.go @@ -15,6 +15,7 @@ import ( chatv2 "paperdebugger/pkg/gen/api/chat/v2" commentv1 "paperdebugger/pkg/gen/api/comment/v1" projectv1 "paperdebugger/pkg/gen/api/project/v1" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" userv1 "paperdebugger/pkg/gen/api/user/v1" // "github.com/grpc-ecosystem/go-grpc-middleware" @@ -106,6 +107,7 @@ func NewGrpcServer( userServer userv1.UserServiceServer, projectServer projectv1.ProjectServiceServer, commentServer commentv1.CommentServiceServer, + usageServer usagev1.UsageServiceServer, ) *GrpcServer { grpcServer := &GrpcServer{} grpcServer.userService = userService @@ -121,5 +123,6 @@ func NewGrpcServer( userv1.RegisterUserServiceServer(grpcServer.Server, userServer) projectv1.RegisterProjectServiceServer(grpcServer.Server, projectServer) commentv1.RegisterCommentServiceServer(grpcServer.Server, commentServer) + usagev1.RegisterUsageServiceServer(grpcServer.Server, usageServer) return grpcServer } diff --git a/internal/api/server.go b/internal/api/server.go index b093c767..d8e9b36a 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -17,6 +17,7 @@ import ( commentv1 "paperdebugger/pkg/gen/api/comment/v1" projectv1 "paperdebugger/pkg/gen/api/project/v1" sharedv1 "paperdebugger/pkg/gen/api/shared/v1" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" userv1 "paperdebugger/pkg/gen/api/user/v1" "github.com/gin-gonic/gin" @@ -105,6 +106,11 @@ func (s *Server) Run(addr string) { s.logger.Fatalf("failed to register comment service grpc gateway: %v", err) return } + err = usagev1.RegisterUsageServiceHandler(context.Background(), mux, client) + if err != nil { + s.logger.Fatalf("failed to register usage service grpc gateway: %v", err) + return + } s.logger.Infof("[PAPERDEBUGGER] http server listening on %s", addr) s.ginServer.Any("/_pd/api/*path", func(c *gin.Context) { mux.ServeHTTP(c.Writer, c.Request) }) diff --git a/internal/api/usage/get_session_usage.go b/internal/api/usage/get_session_usage.go new file mode 100644 index 00000000..06a28718 --- /dev/null +++ b/internal/api/usage/get_session_usage.go @@ -0,0 +1,38 @@ +package usage + +import ( + "context" + + "paperdebugger/internal/libs/contextutil" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" + + "google.golang.org/protobuf/types/known/timestamppb" +) + +func (s *UsageServer) GetSessionUsage( + ctx context.Context, + req *usagev1.GetSessionUsageRequest, +) (*usagev1.GetSessionUsageResponse, error) { + actor, err := contextutil.GetActor(ctx) + if err != nil { + return nil, err + } + + session, err := s.usageService.GetActiveSession(ctx, actor.ID) + if err != nil { + return nil, err + } + + if session == nil { + return &usagev1.GetSessionUsageResponse{ + Session: nil, + }, nil + } + + return &usagev1.GetSessionUsageResponse{ + Session: &usagev1.SessionUsage{ + SessionExpiry: timestamppb.New(session.SessionExpiry.Time()), + TotalTokens: session.TotalTokens, + }, + }, nil +} diff --git a/internal/api/usage/get_weekly_usage.go b/internal/api/usage/get_weekly_usage.go new file mode 100644 index 00000000..f87cad60 --- /dev/null +++ b/internal/api/usage/get_weekly_usage.go @@ -0,0 +1,29 @@ +package usage + +import ( + "context" + + "paperdebugger/internal/libs/contextutil" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" +) + +func (s *UsageServer) GetWeeklyUsage( + ctx context.Context, + req *usagev1.GetWeeklyUsageRequest, +) (*usagev1.GetWeeklyUsageResponse, error) { + actor, err := contextutil.GetActor(ctx) + if err != nil { + return nil, err + } + + stats, err := s.usageService.GetWeeklyUsage(ctx, actor.ID) + if err != nil { + return nil, err + } + + return &usagev1.GetWeeklyUsageResponse{ + Usage: &usagev1.WeeklyUsage{ + TotalTokens: stats.TotalTokens, + }, + }, nil +} diff --git a/internal/api/usage/server.go b/internal/api/usage/server.go new file mode 100644 index 00000000..5d64854e --- /dev/null +++ b/internal/api/usage/server.go @@ -0,0 +1,24 @@ +package usage + +import ( + "paperdebugger/internal/libs/logger" + "paperdebugger/internal/services" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" +) + +type UsageServer struct { + usagev1.UnimplementedUsageServiceServer + + usageService *services.UsageService + logger *logger.Logger +} + +func NewUsageServer( + usageService *services.UsageService, + logger *logger.Logger, +) usagev1.UsageServiceServer { + return &UsageServer{ + usageService: usageService, + logger: logger, + } +} diff --git a/internal/libs/db/db.go b/internal/libs/db/db.go index 52a5548c..8468f73c 100644 --- a/internal/libs/db/db.go +++ b/internal/libs/db/db.go @@ -6,6 +6,7 @@ import ( "paperdebugger/internal/libs/cfg" "paperdebugger/internal/libs/logger" + "paperdebugger/internal/models" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" @@ -43,5 +44,33 @@ func NewDB(cfg *cfg.Cfg, logger *logger.Logger) (*DB, error) { } logger.Info("[MONGO] initialized") - return &DB{Client: client, cfg: cfg, logger: logger}, nil + + db := &DB{Client: client, cfg: cfg, logger: logger} + db.ensureIndexes() + return db, nil +} + +// ensureIndexes creates necessary indexes for the database collections. +func (db *DB) ensureIndexes() { + sessions := db.Database("paperdebugger").Collection((models.LLMSession{}).CollectionName()) + + // TTL index: auto-delete sessions after 30 days + _, err := sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{ + Keys: bson.D{{Key: "session_expiry", Value: 1}}, + Options: options.Index().SetExpireAfterSeconds(30 * 24 * 60 * 60), + }) + if err != nil { + db.logger.Error("Failed to create TTL index on llm_sessions", "error", err) + } + + // Compound index for efficient active session lookups + _, err = sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{ + Keys: bson.D{ + {Key: "user_id", Value: 1}, + {Key: "session_expiry", Value: -1}, + }, + }) + if err != nil { + db.logger.Error("Failed to create compound index on llm_sessions", "error", err) + } } diff --git a/internal/models/usage.go b/internal/models/usage.go new file mode 100644 index 00000000..91d73273 --- /dev/null +++ b/internal/models/usage.go @@ -0,0 +1,19 @@ +package models + +import "go.mongodb.org/mongo-driver/v2/bson" + +// LLMSession represents a user's session for tracking LLM usage and token counts. +type LLMSession struct { + ID bson.ObjectID `bson:"_id"` + UserID bson.ObjectID `bson:"user_id"` + SessionStart bson.DateTime `bson:"session_start"` + SessionExpiry bson.DateTime `bson:"session_expiry"` + PromptTokens int64 `bson:"prompt_tokens"` + CompletionTokens int64 `bson:"completion_tokens"` + TotalTokens int64 `bson:"total_tokens"` + RequestCount int64 `bson:"request_count"` +} + +func (s LLMSession) CollectionName() string { + return "llm_sessions" +} diff --git a/internal/services/toolkit/client/client_v2.go b/internal/services/toolkit/client/client_v2.go index 87a1e26a..4bbcf816 100644 --- a/internal/services/toolkit/client/client_v2.go +++ b/internal/services/toolkit/client/client_v2.go @@ -20,6 +20,7 @@ type AIClientV2 struct { reverseCommentService *services.ReverseCommentService projectService *services.ProjectService + usageService *services.UsageService cfg *cfg.Cfg logger *logger.Logger } @@ -60,6 +61,7 @@ func NewAIClientV2( reverseCommentService *services.ReverseCommentService, projectService *services.ProjectService, + usageService *services.UsageService, cfg *cfg.Cfg, logger *logger.Logger, ) *AIClientV2 { @@ -107,6 +109,7 @@ func NewAIClientV2( reverseCommentService: reverseCommentService, projectService: projectService, + usageService: usageService, cfg: cfg, logger: logger, } diff --git a/internal/services/toolkit/client/completion_v2.go b/internal/services/toolkit/client/completion_v2.go index f10082bf..2c8daa0e 100644 --- a/internal/services/toolkit/client/completion_v2.go +++ b/internal/services/toolkit/client/completion_v2.go @@ -4,11 +4,13 @@ import ( "context" "encoding/json" "paperdebugger/internal/models" + "paperdebugger/internal/services" "paperdebugger/internal/services/toolkit/handler" chatv2 "paperdebugger/pkg/gen/api/chat/v2" "strings" "github.com/openai/openai-go/v3" + "go.mongodb.org/mongo-driver/v2/bson" ) // define []openai.ChatCompletionMessageParamUnion as OpenAIChatHistory @@ -25,8 +27,8 @@ import ( // 1. The full chat history sent to the language model (including any tool call results). // 2. The incremental chat history visible to the user (including tool call results and assistant responses). // 3. An error, if any occurred during the process. -func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { - openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, "", modelSlug, messages, llmProvider) +func (a *AIClientV2) ChatCompletionV2(ctx context.Context, userID bson.ObjectID, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { + openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, userID, "", modelSlug, messages, llmProvider) if err != nil { return nil, nil, err } @@ -54,7 +56,7 @@ func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, mes // - If tool calls are required, it handles them and appends the results to the chat history, then continues the loop. // - If no tool calls are needed, it appends the assistant's response and exits the loop. // - Finally, it returns the updated chat histories and any error encountered. -func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { +func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, userID bson.ObjectID, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { openaiChatHistory := messages inappChatHistory := AppChatHistory{} @@ -97,7 +99,22 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream if len(chunk.Choices) == 0 { // Handle usage information - // fmt.Printf("Usage: %+v\n", chunk.Usage) + if chunk.Usage.TotalTokens > 0 { + // Record usage asynchronously to avoid blocking the response + go func(usage services.UsageRecord) { + bgCtx := context.Background() + if err := a.usageService.RecordUsage(bgCtx, usage); err != nil { + a.logger.Error("Failed to store usage", "error", err) + return + } + + }(services.UsageRecord{ + UserID: userID, + PromptTokens: chunk.Usage.PromptTokens, + CompletionTokens: chunk.Usage.CompletionTokens, + TotalTokens: chunk.Usage.TotalTokens, + }) + } continue } @@ -185,7 +202,6 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream // answer_content += chunk.Choices[0].Delta.Content // fmt.Printf("answer_content: %s\n", answer_content) streamHandler.HandleTextDoneItem(chunk, answer_content, reasoning_content) - break } } diff --git a/internal/services/toolkit/client/get_citation_keys.go b/internal/services/toolkit/client/get_citation_keys.go index 1995d590..5cc43ce5 100644 --- a/internal/services/toolkit/client/get_citation_keys.go +++ b/internal/services/toolkit/client/get_citation_keys.go @@ -241,7 +241,7 @@ func (a *AIClientV2) GetCitationKeys(ctx context.Context, sentence string, userI // Bibliography is placed at the start of the prompt to leverage prompt caching message := fmt.Sprintf("Bibliography: %s\nSentence: %s\nBased on the sentence and bibliography, suggest only the most relevant citation keys separated by commas with no spaces (e.g. key1,key2). Be selective and only include citations that are directly relevant. Avoid suggesting more than 3 citations. If no relevant citations are found, return '%s'.", bibliography, sentence, emptyCitation) - _, resp, err := a.ChatCompletionV2(ctx, "gpt-5.2", OpenAIChatHistory{ + _, resp, err := a.ChatCompletionV2(ctx, userId, "gpt-5.2", OpenAIChatHistory{ openai.SystemMessage("You are a helpful assistant that suggests relevant citation keys."), openai.UserMessage(message), }, llmProvider) diff --git a/internal/services/toolkit/client/get_citation_keys_test.go b/internal/services/toolkit/client/get_citation_keys_test.go index 4d2a857d..802e6bbf 100644 --- a/internal/services/toolkit/client/get_citation_keys_test.go +++ b/internal/services/toolkit/client/get_citation_keys_test.go @@ -25,10 +25,12 @@ func setupTestClient(t *testing.T) (*client.AIClientV2, *services.ProjectService } projectService := services.NewProjectService(dbInstance, cfg.GetCfg(), logger.GetLogger()) + usageService := services.NewUsageService(dbInstance, cfg.GetCfg(), logger.GetLogger()) aiClient := client.NewAIClientV2( dbInstance, &services.ReverseCommentService{}, projectService, + usageService, cfg.GetCfg(), logger.GetLogger(), ) diff --git a/internal/services/toolkit/client/get_conversation_title_v2.go b/internal/services/toolkit/client/get_conversation_title_v2.go index 6c92f0c2..f3fd5c8c 100644 --- a/internal/services/toolkit/client/get_conversation_title_v2.go +++ b/internal/services/toolkit/client/get_conversation_title_v2.go @@ -11,9 +11,10 @@ import ( "github.com/openai/openai-go/v3" "github.com/samber/lo" + "go.mongodb.org/mongo-driver/v2/bson" ) -func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig) (string, error) { +func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, userID bson.ObjectID, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig) (string, error) { messages := lo.Map(inappChatHistory, func(message *chatv2.Message, _ int) string { if _, ok := message.Payload.MessageType.(*chatv2.MessagePayload_Assistant); ok { return fmt.Sprintf("Assistant: %s", message.Payload.GetAssistant().GetContent()) @@ -29,7 +30,7 @@ func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistor message := strings.Join(messages, "\n") message = fmt.Sprintf("%s\nBased on above conversation, generate a short, clear, and descriptive title that summarizes the main topic or purpose of the discussion. The title should be concise, specific, and use natural language. Avoid vague or generic titles. Use abbreviation and short words if possible. Use 3-5 words if possible. Give me the title only, no other text including any other words.", message) - _, resp, err := a.ChatCompletionV2(ctx, "gpt-5-nano", OpenAIChatHistory{ + _, resp, err := a.ChatCompletionV2(ctx, userID, "gpt-5-nano", OpenAIChatHistory{ openai.SystemMessage("You are a helpful assistant that generates a title for a conversation."), openai.UserMessage(message), }, llmProvider) diff --git a/internal/services/toolkit/client/utils_v2.go b/internal/services/toolkit/client/utils_v2.go index 69e73071..47829575 100644 --- a/internal/services/toolkit/client/utils_v2.go +++ b/internal/services/toolkit/client/utils_v2.go @@ -74,6 +74,9 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2) Tools: toolRegistry.GetTools(), ParallelToolCalls: openaiv3.Bool(true), Store: openaiv3.Bool(false), + StreamOptions: openaiv3.ChatCompletionStreamOptionsParam{ + IncludeUsage: openaiv3.Bool(true), + }, } } } @@ -85,6 +88,9 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2) Tools: toolRegistry.GetTools(), // Tool registration is managed centrally by the registry ParallelToolCalls: openaiv3.Bool(true), Store: openaiv3.Bool(false), // Must set to false, because we are construct our own chat history. + StreamOptions: openaiv3.ChatCompletionStreamOptionsParam{ + IncludeUsage: openaiv3.Bool(true), + }, } } diff --git a/internal/services/usage.go b/internal/services/usage.go new file mode 100644 index 00000000..d40a7156 --- /dev/null +++ b/internal/services/usage.go @@ -0,0 +1,175 @@ +package services + +import ( + "context" + "time" + + "paperdebugger/internal/libs/cfg" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/libs/logger" + "paperdebugger/internal/models" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +const SessionDuration = 5 * time.Hour + +type UsageService struct { + BaseService + sessionCollection *mongo.Collection +} + +type UsageRecord struct { + UserID bson.ObjectID + PromptTokens int64 + CompletionTokens int64 + TotalTokens int64 +} + +type UsageStats struct { + PromptTokens int64 `bson:"prompt_tokens"` + CompletionTokens int64 `bson:"completion_tokens"` + TotalTokens int64 `bson:"total_tokens"` + RequestCount int64 `bson:"request_count"` + SessionCount int64 `bson:"session_count"` +} + +func NewUsageService(db *db.DB, cfg *cfg.Cfg, logger *logger.Logger) *UsageService { + base := NewBaseService(db, cfg, logger) + return &UsageService{ + BaseService: base, + sessionCollection: base.db.Collection((models.LLMSession{}).CollectionName()), + } +} + +// RecordUsage updates the active session or creates a new one if none exists. +// Falls back to update if insert fails (handles race when another request created a session). +func (s *UsageService) RecordUsage(ctx context.Context, record UsageRecord) error { + now := time.Now() + nowBson := bson.DateTime(now.UnixMilli()) + + filter := bson.M{ + "user_id": record.UserID, + "session_expiry": bson.M{"$gt": nowBson}, + } + update := bson.M{ + "$inc": bson.M{ + "prompt_tokens": record.PromptTokens, + "completion_tokens": record.CompletionTokens, + "total_tokens": record.TotalTokens, + "request_count": 1, + }, + } + + result, err := s.sessionCollection.UpdateOne(ctx, filter, update) + if err != nil { + return err + } + if result.MatchedCount > 0 { + return nil + } + + // No active session found - create a new one + session := models.LLMSession{ + ID: bson.NewObjectID(), + UserID: record.UserID, + SessionStart: nowBson, + SessionExpiry: bson.DateTime(now.Add(SessionDuration).UnixMilli()), + PromptTokens: record.PromptTokens, + CompletionTokens: record.CompletionTokens, + TotalTokens: record.TotalTokens, + RequestCount: 1, + } + _, err = s.sessionCollection.InsertOne(ctx, session) + if err != nil { + // Insert failed (race condition or other error) - retry update + _, err = s.sessionCollection.UpdateOne(ctx, filter, update) + } + return err +} + +// GetActiveSession returns the current active session for a user, if any. +func (s *UsageService) GetActiveSession(ctx context.Context, userID bson.ObjectID) (*models.LLMSession, error) { + now := bson.DateTime(time.Now().UnixMilli()) + filter := bson.M{ + "user_id": userID, + "session_expiry": bson.M{"$gt": now}, + } + + var session models.LLMSession + err := s.sessionCollection.FindOne(ctx, filter).Decode(&session) + if err == mongo.ErrNoDocuments { + return nil, nil + } + if err != nil { + return nil, err + } + return &session, nil +} + +// GetWeeklyUsage returns aggregated usage for a user for the current week (Monday-Sunday). +func (s *UsageService) GetWeeklyUsage(ctx context.Context, userID bson.ObjectID) (*UsageStats, error) { + weekStart := startOfWeek(time.Now()) + return s.getUsageSince(ctx, userID, weekStart) +} + +func (s *UsageService) getUsageSince(ctx context.Context, userID bson.ObjectID, since time.Time) (*UsageStats, error) { + pipeline := bson.A{ + bson.M{"$match": bson.M{ + "user_id": userID, + "session_start": bson.M{"$gte": bson.DateTime(since.UnixMilli())}, + }}, + bson.M{"$group": bson.M{ + "_id": nil, + "prompt_tokens": bson.M{"$sum": "$prompt_tokens"}, + "completion_tokens": bson.M{"$sum": "$completion_tokens"}, + "total_tokens": bson.M{"$sum": "$total_tokens"}, + "request_count": bson.M{"$sum": "$request_count"}, + "session_count": bson.M{"$sum": 1}, + }}, + } + + cursor, err := s.sessionCollection.Aggregate(ctx, pipeline) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + if cursor.Next(ctx) { + var result UsageStats + if err := cursor.Decode(&result); err != nil { + return nil, err + } + return &result, nil + } + return &UsageStats{}, nil +} + +// startOfWeek returns the start of the week (Monday 00:00:00 UTC). +func startOfWeek(t time.Time) time.Time { + t = t.UTC() + daysFromMonday := (int(t.Weekday()) + 6) % 7 // Sunday=6, Monday=0, Tuesday=1, ... + return time.Date(t.Year(), t.Month(), t.Day()-daysFromMonday, 0, 0, 0, 0, time.UTC) +} + +// ListRecentSessions returns the most recent sessions for a user. +func (s *UsageService) ListRecentSessions(ctx context.Context, userID bson.ObjectID, limit int64) ([]models.LLMSession, error) { + filter := bson.M{"user_id": userID} + opts := options.Find(). + SetSort(bson.D{{Key: "session_start", Value: -1}}). + SetLimit(limit) + + cursor, err := s.sessionCollection.Find(ctx, filter, opts) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + var sessions []models.LLMSession + if err := cursor.All(ctx, &sessions); err != nil { + return nil, err + } + return sessions, nil +} diff --git a/internal/wire.go b/internal/wire.go index f823bc2e..52e6ff28 100644 --- a/internal/wire.go +++ b/internal/wire.go @@ -9,6 +9,7 @@ import ( "paperdebugger/internal/api/chat" "paperdebugger/internal/api/comment" "paperdebugger/internal/api/project" + "paperdebugger/internal/api/usage" "paperdebugger/internal/api/user" "paperdebugger/internal/libs/cfg" "paperdebugger/internal/libs/db" @@ -32,6 +33,7 @@ var Set = wire.NewSet( user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, + usage.NewUsageServer, aiclient.NewAIClient, aiclient.NewAIClientV2, @@ -43,6 +45,7 @@ var Set = wire.NewSet( services.NewProjectService, services.NewPromptService, services.NewOAuthService, + services.NewUsageService, cfg.GetCfg, logger.GetLogger, diff --git a/internal/wire_gen.go b/internal/wire_gen.go index 75c4e91a..a706db0f 100644 --- a/internal/wire_gen.go +++ b/internal/wire_gen.go @@ -13,6 +13,7 @@ import ( "paperdebugger/internal/api/chat" "paperdebugger/internal/api/comment" "paperdebugger/internal/api/project" + "paperdebugger/internal/api/usage" "paperdebugger/internal/api/user" "paperdebugger/internal/libs/cfg" "paperdebugger/internal/libs/db" @@ -38,14 +39,16 @@ func InitializeApp() (*api.Server, error) { aiClient := client.NewAIClient(dbDB, reverseCommentService, projectService, cfgCfg, loggerLogger) chatService := services.NewChatService(dbDB, cfgCfg, loggerLogger) chatServiceServer := chat.NewChatServer(aiClient, chatService, projectService, userService, loggerLogger, cfgCfg) - aiClientV2 := client.NewAIClientV2(dbDB, reverseCommentService, projectService, cfgCfg, loggerLogger) + usageService := services.NewUsageService(dbDB, cfgCfg, loggerLogger) + aiClientV2 := client.NewAIClientV2(dbDB, reverseCommentService, projectService, usageService, cfgCfg, loggerLogger) chatServiceV2 := services.NewChatServiceV2(dbDB, cfgCfg, loggerLogger) chatv2ChatServiceServer := chat.NewChatServerV2(aiClientV2, chatServiceV2, projectService, userService, loggerLogger, cfgCfg) promptService := services.NewPromptService(dbDB, cfgCfg, loggerLogger) userServiceServer := user.NewUserServer(userService, promptService, cfgCfg, loggerLogger) projectServiceServer := project.NewProjectServer(projectService, loggerLogger, cfgCfg) commentServiceServer := comment.NewCommentServer(projectService, chatService, reverseCommentService, loggerLogger, cfgCfg) - grpcServer := api.NewGrpcServer(userService, cfgCfg, authServiceServer, chatServiceServer, chatv2ChatServiceServer, userServiceServer, projectServiceServer, commentServiceServer) + usageServiceServer := usage.NewUsageServer(usageService, loggerLogger) + grpcServer := api.NewGrpcServer(userService, cfgCfg, authServiceServer, chatServiceServer, chatv2ChatServiceServer, userServiceServer, projectServiceServer, commentServiceServer, usageServiceServer) oAuthService := services.NewOAuthService(dbDB, cfgCfg, loggerLogger) oAuthHandler := auth.NewOAuthHandler(oAuthService) ginServer := api.NewGinServer(cfgCfg, oAuthHandler) @@ -55,4 +58,4 @@ func InitializeApp() (*api.Server, error) { // wire.go: -var Set = wire.NewSet(api.NewServer, api.NewGrpcServer, api.NewGinServer, auth.NewOAuthHandler, auth.NewAuthServer, chat.NewChatServer, chat.NewChatServerV2, user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, client.NewAIClient, client.NewAIClientV2, services.NewReverseCommentService, services.NewChatService, services.NewChatServiceV2, services.NewTokenService, services.NewUserService, services.NewProjectService, services.NewPromptService, services.NewOAuthService, cfg.GetCfg, logger.GetLogger, db.NewDB) +var Set = wire.NewSet(api.NewServer, api.NewGrpcServer, api.NewGinServer, auth.NewOAuthHandler, auth.NewAuthServer, chat.NewChatServer, chat.NewChatServerV2, user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, usage.NewUsageServer, client.NewAIClient, client.NewAIClientV2, services.NewReverseCommentService, services.NewChatService, services.NewChatServiceV2, services.NewTokenService, services.NewUserService, services.NewProjectService, services.NewPromptService, services.NewOAuthService, services.NewUsageService, cfg.GetCfg, logger.GetLogger, db.NewDB) diff --git a/pkg/gen/api/chat/v2/chat.pb.go b/pkg/gen/api/chat/v2/chat.pb.go index 0d312c55..485bfd0f 100644 --- a/pkg/gen/api/chat/v2/chat.pb.go +++ b/pkg/gen/api/chat/v2/chat.pb.go @@ -7,13 +7,12 @@ package chatv2 import ( - reflect "reflect" - sync "sync" - unsafe "unsafe" - _ "google.golang.org/genproto/googleapis/api/annotations" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" ) const ( diff --git a/pkg/gen/api/usage/v1/usage.pb.go b/pkg/gen/api/usage/v1/usage.pb.go new file mode 100644 index 00000000..1fcf6299 --- /dev/null +++ b/pkg/gen/api/usage/v1/usage.pb.go @@ -0,0 +1,364 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc (unknown) +// source: usage/v1/usage.proto + +package usagev1 + +import ( + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type SessionUsage struct { + state protoimpl.MessageState `protogen:"open.v1"` + SessionExpiry *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=session_expiry,json=sessionExpiry,proto3" json:"session_expiry,omitempty"` + TotalTokens int64 `protobuf:"varint,2,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SessionUsage) Reset() { + *x = SessionUsage{} + mi := &file_usage_v1_usage_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SessionUsage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SessionUsage) ProtoMessage() {} + +func (x *SessionUsage) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SessionUsage.ProtoReflect.Descriptor instead. +func (*SessionUsage) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{0} +} + +func (x *SessionUsage) GetSessionExpiry() *timestamppb.Timestamp { + if x != nil { + return x.SessionExpiry + } + return nil +} + +func (x *SessionUsage) GetTotalTokens() int64 { + if x != nil { + return x.TotalTokens + } + return 0 +} + +type WeeklyUsage struct { + state protoimpl.MessageState `protogen:"open.v1"` + TotalTokens int64 `protobuf:"varint,1,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WeeklyUsage) Reset() { + *x = WeeklyUsage{} + mi := &file_usage_v1_usage_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WeeklyUsage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WeeklyUsage) ProtoMessage() {} + +func (x *WeeklyUsage) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WeeklyUsage.ProtoReflect.Descriptor instead. +func (*WeeklyUsage) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{1} +} + +func (x *WeeklyUsage) GetTotalTokens() int64 { + if x != nil { + return x.TotalTokens + } + return 0 +} + +type GetSessionUsageRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetSessionUsageRequest) Reset() { + *x = GetSessionUsageRequest{} + mi := &file_usage_v1_usage_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetSessionUsageRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetSessionUsageRequest) ProtoMessage() {} + +func (x *GetSessionUsageRequest) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetSessionUsageRequest.ProtoReflect.Descriptor instead. +func (*GetSessionUsageRequest) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{2} +} + +type GetSessionUsageResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Active session usage, null if no active session + Session *SessionUsage `protobuf:"bytes,1,opt,name=session,proto3" json:"session,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetSessionUsageResponse) Reset() { + *x = GetSessionUsageResponse{} + mi := &file_usage_v1_usage_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetSessionUsageResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetSessionUsageResponse) ProtoMessage() {} + +func (x *GetSessionUsageResponse) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetSessionUsageResponse.ProtoReflect.Descriptor instead. +func (*GetSessionUsageResponse) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{3} +} + +func (x *GetSessionUsageResponse) GetSession() *SessionUsage { + if x != nil { + return x.Session + } + return nil +} + +type GetWeeklyUsageRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetWeeklyUsageRequest) Reset() { + *x = GetWeeklyUsageRequest{} + mi := &file_usage_v1_usage_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetWeeklyUsageRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetWeeklyUsageRequest) ProtoMessage() {} + +func (x *GetWeeklyUsageRequest) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetWeeklyUsageRequest.ProtoReflect.Descriptor instead. +func (*GetWeeklyUsageRequest) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{4} +} + +type GetWeeklyUsageResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Usage *WeeklyUsage `protobuf:"bytes,1,opt,name=usage,proto3" json:"usage,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetWeeklyUsageResponse) Reset() { + *x = GetWeeklyUsageResponse{} + mi := &file_usage_v1_usage_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetWeeklyUsageResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetWeeklyUsageResponse) ProtoMessage() {} + +func (x *GetWeeklyUsageResponse) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetWeeklyUsageResponse.ProtoReflect.Descriptor instead. +func (*GetWeeklyUsageResponse) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{5} +} + +func (x *GetWeeklyUsageResponse) GetUsage() *WeeklyUsage { + if x != nil { + return x.Usage + } + return nil +} + +var File_usage_v1_usage_proto protoreflect.FileDescriptor + +const file_usage_v1_usage_proto_rawDesc = "" + + "\n" + + "\x14usage/v1/usage.proto\x12\busage.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"t\n" + + "\fSessionUsage\x12A\n" + + "\x0esession_expiry\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\rsessionExpiry\x12!\n" + + "\ftotal_tokens\x18\x02 \x01(\x03R\vtotalTokens\"0\n" + + "\vWeeklyUsage\x12!\n" + + "\ftotal_tokens\x18\x01 \x01(\x03R\vtotalTokens\"\x18\n" + + "\x16GetSessionUsageRequest\"K\n" + + "\x17GetSessionUsageResponse\x120\n" + + "\asession\x18\x01 \x01(\v2\x16.usage.v1.SessionUsageR\asession\"\x17\n" + + "\x15GetWeeklyUsageRequest\"E\n" + + "\x16GetWeeklyUsageResponse\x12+\n" + + "\x05usage\x18\x01 \x01(\v2\x15.usage.v1.WeeklyUsageR\x05usage2\x9a\x02\n" + + "\fUsageService\x12\x85\x01\n" + + "\x0fGetSessionUsage\x12 .usage.v1.GetSessionUsageRequest\x1a!.usage.v1.GetSessionUsageResponse\"-\x82\xd3\xe4\x93\x02'\x12%/_pd/api/v1/users/@self/usage/session\x12\x81\x01\n" + + "\x0eGetWeeklyUsage\x12\x1f.usage.v1.GetWeeklyUsageRequest\x1a .usage.v1.GetWeeklyUsageResponse\",\x82\xd3\xe4\x93\x02&\x12$/_pd/api/v1/users/@self/usage/weeklyB\x87\x01\n" + + "\fcom.usage.v1B\n" + + "UsageProtoP\x01Z*paperdebugger/pkg/gen/api/usage/v1;usagev1\xa2\x02\x03UXX\xaa\x02\bUsage.V1\xca\x02\bUsage\\V1\xe2\x02\x14Usage\\V1\\GPBMetadata\xea\x02\tUsage::V1b\x06proto3" + +var ( + file_usage_v1_usage_proto_rawDescOnce sync.Once + file_usage_v1_usage_proto_rawDescData []byte +) + +func file_usage_v1_usage_proto_rawDescGZIP() []byte { + file_usage_v1_usage_proto_rawDescOnce.Do(func() { + file_usage_v1_usage_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_usage_v1_usage_proto_rawDesc), len(file_usage_v1_usage_proto_rawDesc))) + }) + return file_usage_v1_usage_proto_rawDescData +} + +var file_usage_v1_usage_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_usage_v1_usage_proto_goTypes = []any{ + (*SessionUsage)(nil), // 0: usage.v1.SessionUsage + (*WeeklyUsage)(nil), // 1: usage.v1.WeeklyUsage + (*GetSessionUsageRequest)(nil), // 2: usage.v1.GetSessionUsageRequest + (*GetSessionUsageResponse)(nil), // 3: usage.v1.GetSessionUsageResponse + (*GetWeeklyUsageRequest)(nil), // 4: usage.v1.GetWeeklyUsageRequest + (*GetWeeklyUsageResponse)(nil), // 5: usage.v1.GetWeeklyUsageResponse + (*timestamppb.Timestamp)(nil), // 6: google.protobuf.Timestamp +} +var file_usage_v1_usage_proto_depIdxs = []int32{ + 6, // 0: usage.v1.SessionUsage.session_expiry:type_name -> google.protobuf.Timestamp + 0, // 1: usage.v1.GetSessionUsageResponse.session:type_name -> usage.v1.SessionUsage + 1, // 2: usage.v1.GetWeeklyUsageResponse.usage:type_name -> usage.v1.WeeklyUsage + 2, // 3: usage.v1.UsageService.GetSessionUsage:input_type -> usage.v1.GetSessionUsageRequest + 4, // 4: usage.v1.UsageService.GetWeeklyUsage:input_type -> usage.v1.GetWeeklyUsageRequest + 3, // 5: usage.v1.UsageService.GetSessionUsage:output_type -> usage.v1.GetSessionUsageResponse + 5, // 6: usage.v1.UsageService.GetWeeklyUsage:output_type -> usage.v1.GetWeeklyUsageResponse + 5, // [5:7] is the sub-list for method output_type + 3, // [3:5] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name +} + +func init() { file_usage_v1_usage_proto_init() } +func file_usage_v1_usage_proto_init() { + if File_usage_v1_usage_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_usage_v1_usage_proto_rawDesc), len(file_usage_v1_usage_proto_rawDesc)), + NumEnums: 0, + NumMessages: 6, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_usage_v1_usage_proto_goTypes, + DependencyIndexes: file_usage_v1_usage_proto_depIdxs, + MessageInfos: file_usage_v1_usage_proto_msgTypes, + }.Build() + File_usage_v1_usage_proto = out.File + file_usage_v1_usage_proto_goTypes = nil + file_usage_v1_usage_proto_depIdxs = nil +} diff --git a/pkg/gen/api/usage/v1/usage.pb.gw.go b/pkg/gen/api/usage/v1/usage.pb.gw.go new file mode 100644 index 00000000..3a455736 --- /dev/null +++ b/pkg/gen/api/usage/v1/usage.pb.gw.go @@ -0,0 +1,211 @@ +// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. +// source: usage/v1/usage.proto + +/* +Package usagev1 is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. +*/ +package usagev1 + +import ( + "context" + "errors" + "io" + "net/http" + + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +// Suppress "imported and not used" errors +var ( + _ codes.Code + _ io.Reader + _ status.Status + _ = errors.New + _ = runtime.String + _ = utilities.NewDoubleArray + _ = metadata.Join +) + +func request_UsageService_GetSessionUsage_0(ctx context.Context, marshaler runtime.Marshaler, client UsageServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetSessionUsageRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.GetSessionUsage(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_UsageService_GetSessionUsage_0(ctx context.Context, marshaler runtime.Marshaler, server UsageServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetSessionUsageRequest + metadata runtime.ServerMetadata + ) + msg, err := server.GetSessionUsage(ctx, &protoReq) + return msg, metadata, err +} + +func request_UsageService_GetWeeklyUsage_0(ctx context.Context, marshaler runtime.Marshaler, client UsageServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetWeeklyUsageRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.GetWeeklyUsage(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_UsageService_GetWeeklyUsage_0(ctx context.Context, marshaler runtime.Marshaler, server UsageServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetWeeklyUsageRequest + metadata runtime.ServerMetadata + ) + msg, err := server.GetWeeklyUsage(ctx, &protoReq) + return msg, metadata, err +} + +// RegisterUsageServiceHandlerServer registers the http handlers for service UsageService to "mux". +// UnaryRPC :call UsageServiceServer directly. +// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. +// Note that using this registration option will cause many gRPC library features to stop working. Consider using RegisterUsageServiceHandlerFromEndpoint instead. +// GRPC interceptors will not work for this type of registration. To use interceptors, you must use the "runtime.WithMiddlewares" option in the "runtime.NewServeMux" call. +func RegisterUsageServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server UsageServiceServer) error { + mux.Handle(http.MethodGet, pattern_UsageService_GetSessionUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/usage.v1.UsageService/GetSessionUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/session")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_UsageService_GetSessionUsage_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetSessionUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodGet, pattern_UsageService_GetWeeklyUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/usage.v1.UsageService/GetWeeklyUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/weekly")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_UsageService_GetWeeklyUsage_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetWeeklyUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + + return nil +} + +// RegisterUsageServiceHandlerFromEndpoint is same as RegisterUsageServiceHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterUsageServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.NewClient(endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + return RegisterUsageServiceHandler(ctx, mux, conn) +} + +// RegisterUsageServiceHandler registers the http handlers for service UsageService to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func RegisterUsageServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return RegisterUsageServiceHandlerClient(ctx, mux, NewUsageServiceClient(conn)) +} + +// RegisterUsageServiceHandlerClient registers the http handlers for service UsageService +// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "UsageServiceClient". +// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "UsageServiceClient" +// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in +// "UsageServiceClient" to call the correct interceptors. This client ignores the HTTP middlewares. +func RegisterUsageServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux, client UsageServiceClient) error { + mux.Handle(http.MethodGet, pattern_UsageService_GetSessionUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/usage.v1.UsageService/GetSessionUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/session")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_UsageService_GetSessionUsage_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetSessionUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodGet, pattern_UsageService_GetWeeklyUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/usage.v1.UsageService/GetWeeklyUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/weekly")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_UsageService_GetWeeklyUsage_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetWeeklyUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + return nil +} + +var ( + pattern_UsageService_GetSessionUsage_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 2, 5, 2, 6}, []string{"_pd", "api", "v1", "users", "@self", "usage", "session"}, "")) + pattern_UsageService_GetWeeklyUsage_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 2, 5, 2, 6}, []string{"_pd", "api", "v1", "users", "@self", "usage", "weekly"}, "")) +) + +var ( + forward_UsageService_GetSessionUsage_0 = runtime.ForwardResponseMessage + forward_UsageService_GetWeeklyUsage_0 = runtime.ForwardResponseMessage +) diff --git a/pkg/gen/api/usage/v1/usage_grpc.pb.go b/pkg/gen/api/usage/v1/usage_grpc.pb.go new file mode 100644 index 00000000..7d33c1dd --- /dev/null +++ b/pkg/gen/api/usage/v1/usage_grpc.pb.go @@ -0,0 +1,159 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.1 +// - protoc (unknown) +// source: usage/v1/usage.proto + +package usagev1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + UsageService_GetSessionUsage_FullMethodName = "/usage.v1.UsageService/GetSessionUsage" + UsageService_GetWeeklyUsage_FullMethodName = "/usage.v1.UsageService/GetWeeklyUsage" +) + +// UsageServiceClient is the client API for UsageService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type UsageServiceClient interface { + GetSessionUsage(ctx context.Context, in *GetSessionUsageRequest, opts ...grpc.CallOption) (*GetSessionUsageResponse, error) + GetWeeklyUsage(ctx context.Context, in *GetWeeklyUsageRequest, opts ...grpc.CallOption) (*GetWeeklyUsageResponse, error) +} + +type usageServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewUsageServiceClient(cc grpc.ClientConnInterface) UsageServiceClient { + return &usageServiceClient{cc} +} + +func (c *usageServiceClient) GetSessionUsage(ctx context.Context, in *GetSessionUsageRequest, opts ...grpc.CallOption) (*GetSessionUsageResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetSessionUsageResponse) + err := c.cc.Invoke(ctx, UsageService_GetSessionUsage_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *usageServiceClient) GetWeeklyUsage(ctx context.Context, in *GetWeeklyUsageRequest, opts ...grpc.CallOption) (*GetWeeklyUsageResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetWeeklyUsageResponse) + err := c.cc.Invoke(ctx, UsageService_GetWeeklyUsage_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// UsageServiceServer is the server API for UsageService service. +// All implementations must embed UnimplementedUsageServiceServer +// for forward compatibility. +type UsageServiceServer interface { + GetSessionUsage(context.Context, *GetSessionUsageRequest) (*GetSessionUsageResponse, error) + GetWeeklyUsage(context.Context, *GetWeeklyUsageRequest) (*GetWeeklyUsageResponse, error) + mustEmbedUnimplementedUsageServiceServer() +} + +// UnimplementedUsageServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedUsageServiceServer struct{} + +func (UnimplementedUsageServiceServer) GetSessionUsage(context.Context, *GetSessionUsageRequest) (*GetSessionUsageResponse, error) { + return nil, status.Error(codes.Unimplemented, "method GetSessionUsage not implemented") +} +func (UnimplementedUsageServiceServer) GetWeeklyUsage(context.Context, *GetWeeklyUsageRequest) (*GetWeeklyUsageResponse, error) { + return nil, status.Error(codes.Unimplemented, "method GetWeeklyUsage not implemented") +} +func (UnimplementedUsageServiceServer) mustEmbedUnimplementedUsageServiceServer() {} +func (UnimplementedUsageServiceServer) testEmbeddedByValue() {} + +// UnsafeUsageServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to UsageServiceServer will +// result in compilation errors. +type UnsafeUsageServiceServer interface { + mustEmbedUnimplementedUsageServiceServer() +} + +func RegisterUsageServiceServer(s grpc.ServiceRegistrar, srv UsageServiceServer) { + // If the following call panics, it indicates UnimplementedUsageServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&UsageService_ServiceDesc, srv) +} + +func _UsageService_GetSessionUsage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetSessionUsageRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(UsageServiceServer).GetSessionUsage(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: UsageService_GetSessionUsage_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(UsageServiceServer).GetSessionUsage(ctx, req.(*GetSessionUsageRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _UsageService_GetWeeklyUsage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetWeeklyUsageRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(UsageServiceServer).GetWeeklyUsage(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: UsageService_GetWeeklyUsage_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(UsageServiceServer).GetWeeklyUsage(ctx, req.(*GetWeeklyUsageRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// UsageService_ServiceDesc is the grpc.ServiceDesc for UsageService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var UsageService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "usage.v1.UsageService", + HandlerType: (*UsageServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetSessionUsage", + Handler: _UsageService_GetSessionUsage_Handler, + }, + { + MethodName: "GetWeeklyUsage", + Handler: _UsageService_GetWeeklyUsage_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "usage/v1/usage.proto", +} diff --git a/proto/usage/v1/usage.proto b/proto/usage/v1/usage.proto new file mode 100644 index 00000000..d9141dd0 --- /dev/null +++ b/proto/usage/v1/usage.proto @@ -0,0 +1,40 @@ +syntax = "proto3"; + +package usage.v1; + +import "google/api/annotations.proto"; +import "google/protobuf/timestamp.proto"; + +option go_package = "paperdebugger/pkg/gen/api/usage/v1;usagev1"; + +service UsageService { + rpc GetSessionUsage(GetSessionUsageRequest) returns (GetSessionUsageResponse) { + option (google.api.http) = {get: "/_pd/api/v1/users/@self/usage/session"}; + } + + rpc GetWeeklyUsage(GetWeeklyUsageRequest) returns (GetWeeklyUsageResponse) { + option (google.api.http) = {get: "/_pd/api/v1/users/@self/usage/weekly"}; + } +} + +message SessionUsage { + google.protobuf.Timestamp session_expiry = 1; + int64 total_tokens = 2; +} + +message WeeklyUsage { + int64 total_tokens = 1; +} + +message GetSessionUsageRequest {} + +message GetSessionUsageResponse { + // Active session usage, null if no active session + SessionUsage session = 1; +} + +message GetWeeklyUsageRequest {} + +message GetWeeklyUsageResponse { + WeeklyUsage usage = 1; +} diff --git a/webapp/_webapp/src/paperdebugger.tsx b/webapp/_webapp/src/paperdebugger.tsx index 5cdc5e5d..172a897e 100644 --- a/webapp/_webapp/src/paperdebugger.tsx +++ b/webapp/_webapp/src/paperdebugger.tsx @@ -2,6 +2,7 @@ import { Chat } from "./views/chat"; import { Tabs } from "./components/tabs"; import { Settings } from "./views/settings"; import { Prompts } from "./views/prompts"; +import { Usage } from "./views/usage"; import { PdAppBodyContainer } from "./components/pd-app-body-container"; export const PaperDebugger = () => { @@ -23,6 +24,13 @@ export const PaperDebugger = () => { children: , tooltip: "Prompt Library", }, + { + key: "usage", + title: "Usage", + icon: "tabler:chart-bar", + children: , + tooltip: "Usage Statistics", + }, { key: "settings", title: "Settings", diff --git a/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts new file mode 100644 index 00000000..35ec21ae --- /dev/null +++ b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts @@ -0,0 +1,141 @@ +// @generated by protoc-gen-es v2.11.0 with parameter "target=ts" +// @generated from file usage/v1/usage.proto (package usage.v1, syntax proto3) +/* eslint-disable */ + +import type { GenFile, GenMessage, GenService } from "@bufbuild/protobuf/codegenv2"; +import { fileDesc, messageDesc, serviceDesc } from "@bufbuild/protobuf/codegenv2"; +import { file_google_api_annotations } from "@buf/googleapis_googleapis.bufbuild_es/google/api/annotations_pb"; +import type { Timestamp } from "@bufbuild/protobuf/wkt"; +import { file_google_protobuf_timestamp } from "@bufbuild/protobuf/wkt"; +import type { Message } from "@bufbuild/protobuf"; + +/** + * Describes the file usage/v1/usage.proto. + */ +export const file_usage_v1_usage: GenFile = /*@__PURE__*/ + fileDesc("ChR1c2FnZS92MS91c2FnZS5wcm90bxIIdXNhZ2UudjEiWAoMU2Vzc2lvblVzYWdlEjIKDnNlc3Npb25fZXhwaXJ5GAEgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcBIUCgx0b3RhbF90b2tlbnMYAiABKAMiIwoLV2Vla2x5VXNhZ2USFAoMdG90YWxfdG9rZW5zGAEgASgDIhgKFkdldFNlc3Npb25Vc2FnZVJlcXVlc3QiQgoXR2V0U2Vzc2lvblVzYWdlUmVzcG9uc2USJwoHc2Vzc2lvbhgBIAEoCzIWLnVzYWdlLnYxLlNlc3Npb25Vc2FnZSIXChVHZXRXZWVrbHlVc2FnZVJlcXVlc3QiPgoWR2V0V2Vla2x5VXNhZ2VSZXNwb25zZRIkCgV1c2FnZRgBIAEoCzIVLnVzYWdlLnYxLldlZWtseVVzYWdlMpoCCgxVc2FnZVNlcnZpY2UShQEKD0dldFNlc3Npb25Vc2FnZRIgLnVzYWdlLnYxLkdldFNlc3Npb25Vc2FnZVJlcXVlc3QaIS51c2FnZS52MS5HZXRTZXNzaW9uVXNhZ2VSZXNwb25zZSItgtPkkwInEiUvX3BkL2FwaS92MS91c2Vycy9Ac2VsZi91c2FnZS9zZXNzaW9uEoEBCg5HZXRXZWVrbHlVc2FnZRIfLnVzYWdlLnYxLkdldFdlZWtseVVzYWdlUmVxdWVzdBogLnVzYWdlLnYxLkdldFdlZWtseVVzYWdlUmVzcG9uc2UiLILT5JMCJhIkL19wZC9hcGkvdjEvdXNlcnMvQHNlbGYvdXNhZ2Uvd2Vla2x5QocBCgxjb20udXNhZ2UudjFCClVzYWdlUHJvdG9QAVoqcGFwZXJkZWJ1Z2dlci9wa2cvZ2VuL2FwaS91c2FnZS92MTt1c2FnZXYxogIDVVhYqgIIVXNhZ2UuVjHKAghVc2FnZVxWMeICFFVzYWdlXFYxXEdQQk1ldGFkYXRh6gIJVXNhZ2U6OlYxYgZwcm90bzM", [file_google_api_annotations, file_google_protobuf_timestamp]); + +/** + * @generated from message usage.v1.SessionUsage + */ +export type SessionUsage = Message<"usage.v1.SessionUsage"> & { + /** + * @generated from field: google.protobuf.Timestamp session_expiry = 1; + */ + sessionExpiry?: Timestamp; + + /** + * @generated from field: int64 total_tokens = 2; + */ + totalTokens: bigint; +}; + +/** + * Describes the message usage.v1.SessionUsage. + * Use `create(SessionUsageSchema)` to create a new message. + */ +export const SessionUsageSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 0); + +/** + * @generated from message usage.v1.WeeklyUsage + */ +export type WeeklyUsage = Message<"usage.v1.WeeklyUsage"> & { + /** + * @generated from field: int64 total_tokens = 1; + */ + totalTokens: bigint; +}; + +/** + * Describes the message usage.v1.WeeklyUsage. + * Use `create(WeeklyUsageSchema)` to create a new message. + */ +export const WeeklyUsageSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 1); + +/** + * @generated from message usage.v1.GetSessionUsageRequest + */ +export type GetSessionUsageRequest = Message<"usage.v1.GetSessionUsageRequest"> & { +}; + +/** + * Describes the message usage.v1.GetSessionUsageRequest. + * Use `create(GetSessionUsageRequestSchema)` to create a new message. + */ +export const GetSessionUsageRequestSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 2); + +/** + * @generated from message usage.v1.GetSessionUsageResponse + */ +export type GetSessionUsageResponse = Message<"usage.v1.GetSessionUsageResponse"> & { + /** + * Active session usage, null if no active session + * + * @generated from field: usage.v1.SessionUsage session = 1; + */ + session?: SessionUsage; +}; + +/** + * Describes the message usage.v1.GetSessionUsageResponse. + * Use `create(GetSessionUsageResponseSchema)` to create a new message. + */ +export const GetSessionUsageResponseSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 3); + +/** + * @generated from message usage.v1.GetWeeklyUsageRequest + */ +export type GetWeeklyUsageRequest = Message<"usage.v1.GetWeeklyUsageRequest"> & { +}; + +/** + * Describes the message usage.v1.GetWeeklyUsageRequest. + * Use `create(GetWeeklyUsageRequestSchema)` to create a new message. + */ +export const GetWeeklyUsageRequestSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 4); + +/** + * @generated from message usage.v1.GetWeeklyUsageResponse + */ +export type GetWeeklyUsageResponse = Message<"usage.v1.GetWeeklyUsageResponse"> & { + /** + * @generated from field: usage.v1.WeeklyUsage usage = 1; + */ + usage?: WeeklyUsage; +}; + +/** + * Describes the message usage.v1.GetWeeklyUsageResponse. + * Use `create(GetWeeklyUsageResponseSchema)` to create a new message. + */ +export const GetWeeklyUsageResponseSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 5); + +/** + * @generated from service usage.v1.UsageService + */ +export const UsageService: GenService<{ + /** + * @generated from rpc usage.v1.UsageService.GetSessionUsage + */ + getSessionUsage: { + methodKind: "unary"; + input: typeof GetSessionUsageRequestSchema; + output: typeof GetSessionUsageResponseSchema; + }, + /** + * @generated from rpc usage.v1.UsageService.GetWeeklyUsage + */ + getWeeklyUsage: { + methodKind: "unary"; + input: typeof GetWeeklyUsageRequestSchema; + output: typeof GetWeeklyUsageResponseSchema; + }, +}> = /*@__PURE__*/ + serviceDesc(file_usage_v1_usage, 0); + diff --git a/webapp/_webapp/src/query/api.ts b/webapp/_webapp/src/query/api.ts index 4098a018..3ae67e4b 100644 --- a/webapp/_webapp/src/query/api.ts +++ b/webapp/_webapp/src/query/api.ts @@ -224,3 +224,29 @@ export const acceptComments = async (data: PlainMessage const response = await apiclient.post(`/comments/accepted`, data); return fromJson(CommentsAcceptedResponseSchema, response); }; + +// Usage +import { + GetSessionUsageResponseSchema, + GetWeeklyUsageResponseSchema, +} from "../pkg/gen/apiclient/usage/v1/usage_pb"; + +export const getSessionUsage = async () => { + if (!apiclient.hasToken()) { + throw new Error("No token"); + } + const response = await apiclient.get("/users/@self/usage/session", undefined, { + ignoreErrorToast: true, + }); + return fromJson(GetSessionUsageResponseSchema, response); +}; + +export const getWeeklyUsage = async () => { + if (!apiclient.hasToken()) { + throw new Error("No token"); + } + const response = await apiclient.get("/users/@self/usage/weekly", undefined, { + ignoreErrorToast: true, + }); + return fromJson(GetWeeklyUsageResponseSchema, response); +}; diff --git a/webapp/_webapp/src/query/index.ts b/webapp/_webapp/src/query/index.ts index 2c05d959..4c9ea5cc 100644 --- a/webapp/_webapp/src/query/index.ts +++ b/webapp/_webapp/src/query/index.ts @@ -22,6 +22,8 @@ import { upsertUserInstructions, getProjectInstructions, upsertProjectInstructions, + getSessionUsage, + getWeeklyUsage, } from "./api"; import { CreatePromptResponse, @@ -37,6 +39,10 @@ import { GetProjectInstructionsResponse, UpsertProjectInstructionsResponse, } from "../pkg/gen/apiclient/project/v1/project_pb"; +import { + GetSessionUsageResponse, + GetWeeklyUsageResponse, +} from "../pkg/gen/apiclient/usage/v1/usage_pb"; import { useAuthStore } from "../stores/auth-store"; export const useGetProjectQuery = (projectId: string, opts?: UseQueryOptionsOverride) => { @@ -166,3 +172,24 @@ export const useUpsertProjectInstructionsMutation = ( ...opts, }); }; + +// Usage +export const useGetSessionUsageQuery = (opts?: UseQueryOptionsOverride) => { + const { user } = useAuthStore(); + return useQuery({ + queryKey: queryKeys.usage.getSessionUsage().queryKey, + queryFn: () => getSessionUsage(), + enabled: !!user, + ...opts, + }); +}; + +export const useGetWeeklyUsageQuery = (opts?: UseQueryOptionsOverride) => { + const { user } = useAuthStore(); + return useQuery({ + queryKey: queryKeys.usage.getWeeklyUsage().queryKey, + queryFn: () => getWeeklyUsage(), + enabled: !!user, + ...opts, + }); +}; diff --git a/webapp/_webapp/src/query/keys.ts b/webapp/_webapp/src/query/keys.ts index e09bfd7e..dfa3fc34 100644 --- a/webapp/_webapp/src/query/keys.ts +++ b/webapp/_webapp/src/query/keys.ts @@ -5,6 +5,10 @@ export const queryKeys = createQueryKeyStore({ getUser: () => ["users", "@self"], getUserInstructions: () => ["users", "@self", "instructions"], }, + usage: { + getSessionUsage: () => ["users", "@self", "usage", "session"], + getWeeklyUsage: () => ["users", "@self", "usage", "weekly"], + }, prompts: { listPrompts: () => ["users", "@self", "prompts"], }, diff --git a/webapp/_webapp/src/views/usage/index.tsx b/webapp/_webapp/src/views/usage/index.tsx new file mode 100644 index 00000000..36756be7 --- /dev/null +++ b/webapp/_webapp/src/views/usage/index.tsx @@ -0,0 +1,161 @@ +import { Spinner, Button } from "@heroui/react"; +import { Icon } from "@iconify/react"; +import { useState, useEffect } from "react"; +import { TabHeader } from "../../components/tab-header"; +import { useGetSessionUsageQuery, useGetWeeklyUsageQuery } from "../../query"; +import CellWrapper from "../../components/cell-wrapper"; + +const formatNumber = (n: bigint | number | undefined): string => { + if (n === undefined) return "0"; + return Number(n).toLocaleString(); +}; + +const formatTimeRemaining = (timestamp: { seconds?: bigint; nanos?: number } | undefined): string => { + if (!timestamp || !timestamp.seconds) return ""; + const expiryMs = Number(timestamp.seconds) * 1000; + const nowMs = Date.now(); + const diffMs = expiryMs - nowMs; + + if (diffMs <= 0) return ""; + + const totalMinutes = Math.floor(diffMs / 60000); + const hours = Math.floor(totalMinutes / 60); + const minutes = totalMinutes % 60; + + if (hours > 0) { + return `resets in ${hours} hr ${minutes} min`; + } + return `resets in ${minutes} min`; +}; + +const formatLastUpdated = (timestamp: number): string => { + const diffMs = Date.now() - timestamp; + const seconds = Math.floor(diffMs / 1000); + const minutes = Math.floor(seconds / 60); + const hours = Math.floor(minutes / 60); + + if (seconds < 10) return "just now"; + if (seconds < 60) return `${seconds} seconds ago`; + if (minutes === 1) return "1 minute ago"; + if (minutes < 60) return `${minutes} minutes ago`; + if (hours === 1) return "1 hour ago"; + return `${hours} hours ago`; +}; + +const SectionContainer = ({ children }: { children: React.ReactNode }) => { + return
{children}
; +}; + +const SectionTitle = ({ children }: { children: React.ReactNode }) => { + return
{children}
; +}; + +const StatItem = ({ label, value }: { label: string; value: string }) => { + return ( +
+ {label} + {value} +
+ ); +}; + +export const Usage = () => { + const { + data: sessionData, + isLoading: sessionLoading, + dataUpdatedAt: sessionUpdatedAt, + refetch: refetchSession, + isFetching: sessionFetching, + } = useGetSessionUsageQuery(); + const { + data: weeklyData, + isLoading: weeklyLoading, + refetch: refetchWeekly, + isFetching: weeklyFetching, + } = useGetWeeklyUsageQuery(); + + const [, setTick] = useState(0); + + // Update the "last updated" text periodically + useEffect(() => { + const interval = setInterval(() => setTick((t) => t + 1), 10000); + return () => clearInterval(interval); + }, []); + + const isLoading = sessionLoading || weeklyLoading; + const isFetching = sessionFetching || weeklyFetching; + + const handleRefresh = () => { + refetchSession(); + refetchWeekly(); + }; + + if (isLoading) { + return ( +
+ +
+ ); + } + + const session = sessionData?.session; + const weekly = weeklyData?.usage; + + return ( +
+ +
+ + + Current Session + {session?.sessionExpiry && ( + ({formatTimeRemaining(session.sessionExpiry)}) + )} + + {session ? ( + +
+ +
+
+ ) : ( + +
No active session
+
+ )} +
+ + + Weekly Limits + {weekly ? ( + +
+ +
+
+ ) : ( + +
No usage data available
+
+ )} +
+ +
+ + Last updated: {formatLastUpdated(sessionUpdatedAt)} + + +
+
+
+ ); +};