mirror of
https://github.com/Infisical/infisical.git
synced 2026-01-10 07:58:15 -05:00
draft: k8s operator changes
This commit is contained in:
@@ -126,7 +126,7 @@ export function createEventStreamClient(redis: Redis, options: IEventStreamClien
|
||||
|
||||
await redis.set(key, "1", "EX", 60);
|
||||
|
||||
stream.push("1");
|
||||
send({ type: "ping" });
|
||||
};
|
||||
|
||||
const close = () => {
|
||||
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
|
||||
"k8s.io/apimachinery/pkg/api/errors"
|
||||
"k8s.io/apimachinery/pkg/runtime"
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
ctrl "sigs.k8s.io/controller-runtime"
|
||||
"sigs.k8s.io/controller-runtime/pkg/builder"
|
||||
"sigs.k8s.io/controller-runtime/pkg/client"
|
||||
"sigs.k8s.io/controller-runtime/pkg/event"
|
||||
"sigs.k8s.io/controller-runtime/pkg/handler"
|
||||
"sigs.k8s.io/controller-runtime/pkg/predicate"
|
||||
"sigs.k8s.io/controller-runtime/pkg/reconcile"
|
||||
"sigs.k8s.io/controller-runtime/pkg/source"
|
||||
|
||||
defaultErrors "errors"
|
||||
@@ -58,7 +60,6 @@ func (r *InfisicalSecretReconciler) GetLogger(req ctrl.Request) logr.Logger {
|
||||
// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.13.1/pkg/reconcile
|
||||
|
||||
func (r *InfisicalSecretReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
|
||||
|
||||
logger := r.GetLogger(req)
|
||||
|
||||
var infisicalSecretCRD secretsv1alpha1.InfisicalSecret
|
||||
@@ -204,8 +205,17 @@ func (r *InfisicalSecretReconciler) SetupWithManager(mgr ctrl.Manager) error {
|
||||
r.SourceCh = make(chan event.GenericEvent)
|
||||
|
||||
return ctrl.NewControllerManagedBy(mgr).
|
||||
Watches(
|
||||
&source.Channel{Source: r.SourceCh},
|
||||
handler.EnqueueRequestsFromMapFunc(r.findSecretsForCluster),
|
||||
).
|
||||
For(&secretsv1alpha1.InfisicalSecret{}, builder.WithPredicates(predicate.Funcs{
|
||||
GenericFunc: func(ge event.GenericEvent) bool {
|
||||
fmt.Println("Generic event recieved")
|
||||
return true
|
||||
},
|
||||
UpdateFunc: func(e event.UpdateEvent) bool {
|
||||
println("UpdateFunc event recieved")
|
||||
if e.ObjectOld.GetGeneration() == e.ObjectNew.GetGeneration() {
|
||||
return false // Skip reconciliation for status-only changes
|
||||
}
|
||||
@@ -228,9 +238,34 @@ func (r *InfisicalSecretReconciler) SetupWithManager(mgr ctrl.Manager) error {
|
||||
return true
|
||||
},
|
||||
})).
|
||||
Watches(
|
||||
&source.Channel{Source: r.SourceCh},
|
||||
&handler.EnqueueRequestForObject{},
|
||||
).
|
||||
Complete(r)
|
||||
|
||||
}
|
||||
|
||||
func (r *InfisicalSecretReconciler) findSecretsForCluster(o client.Object) []reconcile.Request {
|
||||
ctx := context.Background()
|
||||
secrets := &secretsv1alpha1.InfisicalSecretList{}
|
||||
|
||||
requests := []reconcile.Request{}
|
||||
|
||||
if err := r.List(ctx, secrets); err != nil {
|
||||
fmt.Println(err)
|
||||
return requests
|
||||
}
|
||||
|
||||
for _, sec := range secrets.Items {
|
||||
if sec.GetName() == o.GetName() && sec.GetNamespace() == o.GetNamespace() {
|
||||
requests = append(requests, reconcile.Request{
|
||||
NamespacedName: types.NamespacedName{
|
||||
Namespace: o.GetNamespace(),
|
||||
Name: o.GetName(),
|
||||
},
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println(requests)
|
||||
|
||||
return requests
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
tpl "text/template"
|
||||
|
||||
@@ -28,6 +29,7 @@ import (
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
ctrl "sigs.k8s.io/controller-runtime"
|
||||
"sigs.k8s.io/controller-runtime/pkg/client"
|
||||
"sigs.k8s.io/controller-runtime/pkg/event"
|
||||
)
|
||||
|
||||
func (r *InfisicalSecretReconciler) handleAuthentication(ctx context.Context, infisicalSecret v1alpha1.InfisicalSecret, infisicalClient infisicalSdk.InfisicalClientInterface) (util.AuthenticationDetails, error) {
|
||||
@@ -464,10 +466,10 @@ func (r *InfisicalSecretReconciler) getResourceVariables(infisicalSecret v1alpha
|
||||
})
|
||||
|
||||
infisicalSecretResourceVariablesMap[string(infisicalSecret.UID)] = util.ResourceVariables{
|
||||
InfisicalClient: client,
|
||||
CancelCtx: cancel,
|
||||
AuthDetails: util.AuthenticationDetails{},
|
||||
EventStreamClient: sse.NewClient(api.API_HOST_URL),
|
||||
InfisicalClient: client,
|
||||
CancelCtx: cancel,
|
||||
AuthDetails: util.AuthenticationDetails{},
|
||||
ServerSentEvents: sse.NewConnectionRegistry(),
|
||||
}
|
||||
|
||||
resourceVariables = infisicalSecretResourceVariablesMap[string(infisicalSecret.UID)]
|
||||
@@ -477,7 +479,6 @@ func (r *InfisicalSecretReconciler) getResourceVariables(infisicalSecret v1alpha
|
||||
}
|
||||
|
||||
return resourceVariables
|
||||
|
||||
}
|
||||
|
||||
func (r *InfisicalSecretReconciler) updateResourceVariables(infisicalSecret v1alpha1.InfisicalSecret, resourceVariables util.ResourceVariables) {
|
||||
@@ -485,7 +486,6 @@ func (r *InfisicalSecretReconciler) updateResourceVariables(infisicalSecret v1al
|
||||
}
|
||||
|
||||
func (r *InfisicalSecretReconciler) ReconcileInfisicalSecret(ctx context.Context, logger logr.Logger, infisicalSecret *v1alpha1.InfisicalSecret, managedKubeSecretReferences []v1alpha1.ManagedKubeSecretConfig, managedKubeConfigMapReferences []v1alpha1.ManagedKubeConfigMapConfig) (int, error) {
|
||||
|
||||
if infisicalSecret == nil {
|
||||
return 0, fmt.Errorf("infisicalSecret is nil")
|
||||
}
|
||||
@@ -506,9 +506,10 @@ func (r *InfisicalSecretReconciler) ReconcileInfisicalSecret(ctx context.Context
|
||||
}
|
||||
|
||||
r.updateResourceVariables(*infisicalSecret, util.ResourceVariables{
|
||||
InfisicalClient: infisicalClient,
|
||||
CancelCtx: cancelCtx,
|
||||
AuthDetails: authDetails,
|
||||
InfisicalClient: infisicalClient,
|
||||
CancelCtx: cancelCtx,
|
||||
AuthDetails: authDetails,
|
||||
ServerSentEvents: sse.NewConnectionRegistry(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -577,71 +578,116 @@ func (r *InfisicalSecretReconciler) EnsureEventStream(ctx context.Context, logge
|
||||
return fmt.Errorf("infisicalSecret is nil")
|
||||
}
|
||||
|
||||
resourceVariables := r.getResourceVariables(*secret)
|
||||
infiscalClient := resourceVariables.InfisicalClient
|
||||
projectSlug := resourceVariables.AuthDetails.MachineIdentityScope.ProjectSlug
|
||||
variables := r.getResourceVariables(*secret)
|
||||
|
||||
if !variables.AuthDetails.IsMachineIdentityAuth {
|
||||
return fmt.Errorf("only machine identity is supported for subscriptions")
|
||||
}
|
||||
|
||||
projectSlug := variables.AuthDetails.MachineIdentityScope.ProjectSlug
|
||||
secretsPath := variables.AuthDetails.MachineIdentityScope.SecretsPath
|
||||
envSlug := variables.AuthDetails.MachineIdentityScope.EnvSlug
|
||||
|
||||
infiscalClient := variables.InfisicalClient
|
||||
conn := variables.ServerSentEvents
|
||||
|
||||
token := infiscalClient.Auth().GetAccessToken()
|
||||
|
||||
proj, err := util.GetProjectBySlug(token, projectSlug)
|
||||
project, err := util.GetProjectBySlug(token, projectSlug)
|
||||
|
||||
logger.Info("Project", "project", proj, "slug", projectSlug)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get project [err=%s]", err)
|
||||
}
|
||||
|
||||
client := sse.NewClient(fmt.Sprintf("%s/v1/events/subscribe/project-events", api.API_HOST_URL))
|
||||
|
||||
registers := []api.SubProjectEventsRequestRegister{
|
||||
api.SubProjectEventsRequestRegister{
|
||||
Event: "secret:delete",
|
||||
Conditions: &api.SubProjectEventsRequestCondition{
|
||||
SecretPath: "/**",
|
||||
EnvironmentSlug: secret.Spec.Authentication.UniversalAuth.SecretsScope.EnvSlug,
|
||||
},
|
||||
},
|
||||
if variables.AuthDetails.MachineIdentityScope.Recursive {
|
||||
secretsPath = fmt.Sprint(secretsPath, "**")
|
||||
}
|
||||
|
||||
b, err := json.Marshal(api.SubProjectEventsRequest{
|
||||
ProjectID: proj.ID,
|
||||
Register: registers,
|
||||
conditions := &api.SubProjectEventsRequestCondition{
|
||||
SecretPath: secretsPath,
|
||||
EnvironmentSlug: envSlug,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(api.SubProjectEventsRequest{
|
||||
ProjectID: project.ID,
|
||||
Register: []api.SubProjectEventsRequestRegister{
|
||||
{
|
||||
Event: "secret:create",
|
||||
Conditions: conditions,
|
||||
},
|
||||
{
|
||||
Event: "secret:update",
|
||||
Conditions: conditions,
|
||||
},
|
||||
{
|
||||
Event: "secret:delete",
|
||||
Conditions: conditions,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("CallSubscribeProjectEvents: Unable to marshal body [err=%s]", err)
|
||||
return fmt.Errorf("CallSubscribeProjectEvents: unable to marshal body [err=%s]", err)
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
"User-Agent": api.USER_AGENT_NAME,
|
||||
"Authorization": fmt.Sprint("Bearer ", token),
|
||||
}
|
||||
events, errors, err := conn.Subscribe(func() (*http.Request, error) {
|
||||
headers := map[string]string{
|
||||
"User-Agent": api.USER_AGENT_NAME,
|
||||
"Authorization": fmt.Sprint("Bearer ", token),
|
||||
}
|
||||
|
||||
events, errors, err := client.Connect("POST", headers, strings.NewReader(string(b)))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/v1/events/subscribe/project-events", api.API_HOST_URL), strings.NewReader(string(body)))
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
return req, err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to connect to SSE server [err=%s]", err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to project events [err=%s]", err)
|
||||
return fmt.Errorf("unable to connect to SSE server [err=%s]", err)
|
||||
}
|
||||
|
||||
outer:
|
||||
for {
|
||||
select {
|
||||
case event := <-events:
|
||||
logger.Info("Received event", "event", event)
|
||||
case ev := <-events:
|
||||
logger.Info("Received event", "secret", secret, "event", ev)
|
||||
r.SourceCh <- event.GenericEvent{
|
||||
Object: secret,
|
||||
}
|
||||
logger.Info("Send to channel")
|
||||
case err := <-errors:
|
||||
logger.Error(err, "Error occurred")
|
||||
break outer
|
||||
case <-ctx.Done():
|
||||
logger.Info("Context done")
|
||||
return nil
|
||||
break outer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *InfisicalSecretReconciler) CloseEventStream(ctx context.Context, logger logr.Logger, infisicalSecretCRD *secretsv1alpha1.InfisicalSecret) error {
|
||||
logger.Info("Event watcher disabled")
|
||||
// ensure event watcher is running
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *InfisicalSecretReconciler) CloseEventStream(ctx context.Context, logger logr.Logger, secret *secretsv1alpha1.InfisicalSecret) error {
|
||||
logger.Info("Event watcher disabled")
|
||||
|
||||
if secret == nil {
|
||||
return fmt.Errorf("infisicalSecret is nil")
|
||||
}
|
||||
|
||||
variables := r.getResourceVariables(*secret)
|
||||
|
||||
if !variables.AuthDetails.IsMachineIdentityAuth {
|
||||
return fmt.Errorf("only machine identity is supported for subscriptions")
|
||||
}
|
||||
|
||||
conn := variables.ServerSentEvents
|
||||
|
||||
if _, ok := conn.Get(); ok {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Infisical/infisical/k8-operator/packages/model"
|
||||
"github.com/Infisical/infisical/k8-operator/packages/util/sse"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
@@ -170,25 +167,3 @@ func CallGetProjectByIDv2(httpClient *resty.Client, request GetProjectByIDReques
|
||||
return projectResponse, nil
|
||||
|
||||
}
|
||||
|
||||
func CallSubscribeProjectEvents(projectID, token string, body SubProjectEventsRequest) (<-chan sse.SSEEvent, <-chan error, error) {
|
||||
client := sse.NewClient(fmt.Sprintf("%s/v1/events/subscribe/project-events", API_HOST_URL))
|
||||
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("CallSubscribeProjectEvents: Unable to marshal body [err=%s]", err)
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
"User-Agent": USER_AGENT_NAME,
|
||||
"Authorization": fmt.Sprint("Bearer ", token),
|
||||
}
|
||||
|
||||
events, errors, err := client.Connect("POST", headers, strings.NewReader(string(b)))
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("CallSubscribeProjectEvents: Unable to connect to SSE server [err=%s]", err)
|
||||
}
|
||||
|
||||
return events, errors, err
|
||||
}
|
||||
|
||||
59
k8-operator/packages/util/handler.go
Normal file
59
k8-operator/packages/util/handler.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
"k8s.io/client-go/util/workqueue"
|
||||
"sigs.k8s.io/controller-runtime/pkg/event"
|
||||
"sigs.k8s.io/controller-runtime/pkg/reconcile"
|
||||
)
|
||||
|
||||
// computeMaxJitterDuration returns a random duration between 0 and max.
|
||||
// This is useful for introducing jitter to event processing.
|
||||
func computeMaxJitterDuration(max time.Duration) (time.Duration, time.Duration) {
|
||||
if max <= 0 {
|
||||
return 0, 0
|
||||
}
|
||||
jitter := time.Duration(rand.Int63n(int64(max)))
|
||||
return max, jitter
|
||||
}
|
||||
|
||||
// EnqueueDelayedEventHandler enqueues reconcile requests with a random delay (jitter)
|
||||
// to spread the load and avoid thundering herd issues.
|
||||
type EnqueueDelayedEventHandler struct {
|
||||
Delay time.Duration
|
||||
}
|
||||
|
||||
func (e *EnqueueDelayedEventHandler) Create(_ event.CreateEvent, _ workqueue.RateLimitingInterface) {
|
||||
}
|
||||
|
||||
func (e *EnqueueDelayedEventHandler) Update(_ event.UpdateEvent, _ workqueue.RateLimitingInterface) {
|
||||
}
|
||||
|
||||
func (e *EnqueueDelayedEventHandler) Delete(_ event.DeleteEvent, _ workqueue.RateLimitingInterface) {
|
||||
}
|
||||
|
||||
func (e *EnqueueDelayedEventHandler) Generic(evt event.GenericEvent, q workqueue.RateLimitingInterface) {
|
||||
fmt.Println(evt)
|
||||
if evt.Object == nil {
|
||||
return
|
||||
}
|
||||
|
||||
req := reconcile.Request{
|
||||
NamespacedName: types.NamespacedName{
|
||||
Namespace: evt.Object.GetNamespace(),
|
||||
Name: evt.Object.GetName(),
|
||||
},
|
||||
}
|
||||
|
||||
_, delay := computeMaxJitterDuration(e.Delay)
|
||||
|
||||
if delay > 0 {
|
||||
q.AddAfter(req, delay)
|
||||
} else {
|
||||
q.Add(req)
|
||||
}
|
||||
}
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
)
|
||||
|
||||
type ResourceVariables struct {
|
||||
InfisicalClient infisicalSdk.InfisicalClientInterface
|
||||
CancelCtx context.CancelFunc
|
||||
AuthDetails AuthenticationDetails
|
||||
EventStreamClient *sse.SSEClient
|
||||
InfisicalClient infisicalSdk.InfisicalClientInterface
|
||||
CancelCtx context.CancelFunc
|
||||
AuthDetails AuthenticationDetails
|
||||
ServerSentEvents sse.ConnectionRegistry
|
||||
}
|
||||
|
||||
116
k8-operator/packages/util/sse/client.go
Normal file
116
k8-operator/packages/util/sse/client.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SSEEvent struct {
|
||||
ID string
|
||||
Event string
|
||||
Data string
|
||||
}
|
||||
|
||||
// SSEClient handles SSE connections
|
||||
type SSEClient struct {
|
||||
URL string
|
||||
Client *http.Client
|
||||
LastHealthCheck time.Time
|
||||
mu *sync.Mutex // for safe concurrent access to LastHealthCheck
|
||||
}
|
||||
|
||||
// NewClient creates a new SSE client
|
||||
func NewClient() SSEClient {
|
||||
return SSEClient{
|
||||
mu: &sync.Mutex{},
|
||||
Client: &http.Client{
|
||||
Timeout: 0, // No timeout for streaming
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes SSE connection and returns a channel of events
|
||||
func (c *SSEClient) Connect(req *http.Request) (<-chan SSEEvent, <-chan error, error) {
|
||||
// Set required headers for SSE
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
resp.Body.Close()
|
||||
return nil, nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
eventChan := make(chan SSEEvent)
|
||||
errorChan := make(chan error)
|
||||
|
||||
go c.stream(resp.Body, eventChan, errorChan)
|
||||
|
||||
return eventChan, errorChan, nil
|
||||
}
|
||||
|
||||
func (c *SSEClient) stream(body io.ReadCloser, eventChan chan<- SSEEvent, errorChan chan<- error) {
|
||||
defer body.Close()
|
||||
defer close(eventChan)
|
||||
defer close(errorChan)
|
||||
|
||||
scanner := bufio.NewScanner(body)
|
||||
var event SSEEvent
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// End of event
|
||||
if line == "" {
|
||||
if event.Data != "" || event.Event != "" {
|
||||
if strings.TrimSpace(event.Data) == "1" {
|
||||
c.mu.Lock()
|
||||
c.LastHealthCheck = time.Now()
|
||||
c.mu.Unlock()
|
||||
} else if event.Event != "ping" {
|
||||
eventChan <- event
|
||||
}
|
||||
|
||||
event = SSEEvent{} // Reset for next event
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(line, "data:"):
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if event.Data != "" {
|
||||
event.Data += "\n"
|
||||
}
|
||||
event.Data += data
|
||||
|
||||
case strings.HasPrefix(line, "event:"):
|
||||
event.Event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||
|
||||
case strings.HasPrefix(line, "id:"):
|
||||
event.ID = strings.TrimSpace(strings.TrimPrefix(line, "id:"))
|
||||
|
||||
case strings.HasPrefix(line, "retry:"):
|
||||
// Optional: parse and apply retry interval here
|
||||
|
||||
case strings.HasPrefix(line, ":"):
|
||||
// Comment line — ignored
|
||||
default:
|
||||
// Unknown line format — can log/debug if needed
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
errorChan <- err
|
||||
}
|
||||
}
|
||||
@@ -1,109 +1,137 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SSEEvent struct {
|
||||
ID string
|
||||
Event string
|
||||
Data string
|
||||
// ConnectionMeta holds metadata about an active SSE connection
|
||||
type ConnectionMeta struct {
|
||||
EventChan <-chan SSEEvent
|
||||
ErrorChan <-chan error
|
||||
Cancel context.CancelFunc
|
||||
LastPingAt time.Time
|
||||
}
|
||||
|
||||
// SSEClient handles SSE connections
|
||||
type SSEClient struct {
|
||||
URL string
|
||||
Client *http.Client
|
||||
// ConnectionRegistry manages a single SSE connection with a shared client
|
||||
type ConnectionRegistry struct {
|
||||
Ticker *time.Ticker
|
||||
meta *ConnectionMeta
|
||||
client SSEClient
|
||||
}
|
||||
|
||||
// NewSSEClient creates a new SSE client
|
||||
func NewClient(url string) *SSEClient {
|
||||
return &SSEClient{
|
||||
URL: url,
|
||||
Client: &http.Client{
|
||||
Timeout: 0, // No timeout for streaming
|
||||
},
|
||||
// NewConnectionRegistry creates a new registry
|
||||
func NewConnectionRegistry() ConnectionRegistry {
|
||||
return ConnectionRegistry{
|
||||
Ticker: time.NewTicker(time.Second * 30),
|
||||
client: NewClient(),
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes SSE connection and returns a channel of events
|
||||
func (c *SSEClient) Connect(method string, headers map[string]string, body io.Reader) (<-chan SSEEvent, <-chan error, error) {
|
||||
req, err := http.NewRequest(method, c.URL, body)
|
||||
// GetOrCreate returns existing connection or creates a new one
|
||||
func (r *ConnectionRegistry) GetOrCreate(
|
||||
onBuild func() (*http.Request, error),
|
||||
) (*ConnectionMeta, error) {
|
||||
// First try to get existing connection
|
||||
if r.meta != nil {
|
||||
return r.meta, nil
|
||||
}
|
||||
|
||||
// Create new connection
|
||||
req, err := onBuild()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build request: %w", err)
|
||||
}
|
||||
|
||||
// Add cancellation context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
eventChan, errorChan, err := r.client.Connect(req)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
meta := &ConnectionMeta{
|
||||
EventChan: eventChan,
|
||||
ErrorChan: errorChan,
|
||||
Cancel: cancel,
|
||||
LastPingAt: time.Now(),
|
||||
}
|
||||
|
||||
r.meta = meta
|
||||
|
||||
// Start cleanup monitor for this connection
|
||||
go r.monitor(ctx, meta)
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// Get retrieves the existing connection
|
||||
func (r *ConnectionRegistry) Get() (*ConnectionMeta, bool) {
|
||||
return r.meta, r.meta != nil
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (r *ConnectionRegistry) Close() {
|
||||
if r.meta != nil {
|
||||
r.meta.Cancel()
|
||||
r.meta = nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnected returns whether there's an active connection
|
||||
func (r *ConnectionRegistry) IsConnected() bool {
|
||||
return r.meta != nil
|
||||
}
|
||||
|
||||
// monitorConnection watches for connection closure and cleans up
|
||||
func (r *ConnectionRegistry) monitor(ctx context.Context, meta *ConnectionMeta) {
|
||||
outer:
|
||||
for range r.Ticker.C {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break outer
|
||||
default:
|
||||
if r.IsConnected() && time.Since(r.meta.LastPingAt) > 2*time.Minute {
|
||||
fmt.Println("Last ping was more than 2 minutes ago")
|
||||
r.Close()
|
||||
break outer
|
||||
} else {
|
||||
fmt.Println("Last ping was within the last 2 minutes")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up from registry
|
||||
if r.meta == meta {
|
||||
r.meta = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectionInfo provides read-only info about a connection
|
||||
type ConnectionInfo struct {
|
||||
LastPingAt time.Time
|
||||
}
|
||||
|
||||
// Subscribe provides a convenient way to get events from the connection
|
||||
func (r *ConnectionRegistry) Subscribe(
|
||||
onBuild func() (*http.Request, error),
|
||||
) (<-chan SSEEvent, <-chan error, error) {
|
||||
meta, err := r.GetOrCreate(onBuild)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Set required headers for SSE
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := c.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
resp.Body.Close()
|
||||
return nil, nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
eventChan := make(chan SSEEvent)
|
||||
errorChan := make(chan error)
|
||||
|
||||
go c.readEvents(resp.Body, eventChan, errorChan)
|
||||
|
||||
return eventChan, errorChan, nil
|
||||
return meta.EventChan, meta.ErrorChan, nil
|
||||
}
|
||||
|
||||
// readEvents reads and parses SSE events from the response body
|
||||
func (c *SSEClient) readEvents(body io.ReadCloser, eventChan chan<- SSEEvent, errorChan chan<- error) {
|
||||
defer body.Close()
|
||||
defer close(eventChan)
|
||||
defer close(errorChan)
|
||||
|
||||
scanner := bufio.NewScanner(body)
|
||||
var event SSEEvent
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Empty line indicates end of event
|
||||
if line == "" {
|
||||
if event.Data != "" || event.Event != "" {
|
||||
eventChan <- event
|
||||
event = SSEEvent{} // Reset for next event
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse event fields
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
if event.Data != "" {
|
||||
event.Data += "\n"
|
||||
}
|
||||
event.Data += strings.TrimPrefix(line, "data: ")
|
||||
} else if strings.HasPrefix(line, "event: ") {
|
||||
event.Event = strings.TrimPrefix(line, "event: ")
|
||||
} else if strings.HasPrefix(line, "id: ") {
|
||||
event.ID = strings.TrimPrefix(line, "id: ")
|
||||
} else if strings.HasPrefix(line, "retry: ") {
|
||||
// Parse retry value (implementation omitted for brevity)
|
||||
} else if strings.HasPrefix(line, ": ") {
|
||||
// Comment line, ignore
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
errorChan <- err
|
||||
}
|
||||
// Reconnect closes existing connection and creates a new one
|
||||
func (r *ConnectionRegistry) Reconnect(
|
||||
onBuild func() (*http.Request, error),
|
||||
) (*ConnectionMeta, error) {
|
||||
r.Close()
|
||||
return r.GetOrCreate(onBuild)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user