604 lines
17 KiB
Go
604 lines
17 KiB
Go
package streams
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"sync"
|
||
"time"
|
||
|
||
"acme-reverseproxy/config"
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
// StreamManager управляет TCP/UDP стримами
|
||
type StreamManager struct {
|
||
streams map[string]config.StreamConfig
|
||
tcpListener net.Listener
|
||
udpConn *net.UDPConn
|
||
mu sync.RWMutex
|
||
ctx context.Context
|
||
cancel context.CancelFunc
|
||
wg sync.WaitGroup
|
||
// Поддержка отдельных портов для каждого домена
|
||
tcpListeners map[string]net.Listener
|
||
udpConns map[string]*net.UDPConn
|
||
}
|
||
|
||
// NewStreamManager создает новый менеджер стримов
|
||
func NewStreamManager(streams map[string]config.StreamConfig) *StreamManager {
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
return &StreamManager{
|
||
streams: streams,
|
||
ctx: ctx,
|
||
cancel: cancel,
|
||
tcpListeners: make(map[string]net.Listener),
|
||
udpConns: make(map[string]*net.UDPConn),
|
||
}
|
||
}
|
||
|
||
// Start запускает TCP и UDP серверы
|
||
func (sm *StreamManager) Start() error {
|
||
if len(sm.streams) == 0 {
|
||
logrus.Info("No streams configured, skipping stream manager initialization")
|
||
return nil
|
||
}
|
||
|
||
logrus.Infof("Starting stream manager with %d streams", len(sm.streams))
|
||
|
||
// Запускаем TCP сервер
|
||
if err := sm.startTCP(); err != nil {
|
||
return fmt.Errorf("failed to start TCP server: %w", err)
|
||
}
|
||
|
||
// Запускаем UDP сервер
|
||
if err := sm.startUDP(); err != nil {
|
||
sm.tcpListener.Close()
|
||
return fmt.Errorf("failed to start UDP server: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// startTCP запускает TCP сервер на общем порту 10000
|
||
// и отдельные порты для каждого домена (10001, 10002, ...)
|
||
func (sm *StreamManager) startTCP() error {
|
||
// Общий сервер на порту 10000 (требует указания домена в заголовке)
|
||
listener, err := net.Listen("tcp", ":10000")
|
||
if err != nil {
|
||
return fmt.Errorf("failed to listen on port 10000: %w", err)
|
||
}
|
||
|
||
sm.tcpListener = listener
|
||
logrus.Info("TCP stream server started on port 10000 (shared mode)")
|
||
|
||
sm.wg.Add(1)
|
||
go func() {
|
||
defer sm.wg.Done()
|
||
for {
|
||
select {
|
||
case <-sm.ctx.Done():
|
||
logrus.Info("TCP stream server shutting down")
|
||
return
|
||
default:
|
||
}
|
||
|
||
conn, err := listener.Accept()
|
||
if err != nil {
|
||
select {
|
||
case <-sm.ctx.Done():
|
||
return
|
||
default:
|
||
logrus.Errorf("Failed to accept TCP connection: %v", err)
|
||
}
|
||
continue
|
||
}
|
||
|
||
sm.wg.Add(1)
|
||
go func(c net.Conn) {
|
||
defer sm.wg.Done()
|
||
defer c.Close()
|
||
sm.handleTCPConnection(c, "")
|
||
}(conn)
|
||
}
|
||
}()
|
||
|
||
// Отдельные порты для каждого домена (прозрачный режим)
|
||
port := 10001
|
||
for domain, cfg := range sm.streams {
|
||
if cfg.Protocol != "tcp" {
|
||
continue
|
||
}
|
||
addr := fmt.Sprintf(":%d", port)
|
||
listener, err := net.Listen("tcp", addr)
|
||
if err != nil {
|
||
logrus.Warnf("Failed to listen on port %d for domain %s: %v", port, domain, err)
|
||
port++
|
||
continue
|
||
}
|
||
sm.tcpListeners[domain] = listener
|
||
logrus.Infof("TCP stream server for %s started on port %d (transparent mode)", domain, port)
|
||
|
||
sm.wg.Add(1)
|
||
go func(d string, l net.Listener) {
|
||
defer sm.wg.Done()
|
||
for {
|
||
select {
|
||
case <-sm.ctx.Done():
|
||
logrus.Infof("TCP stream server for %s shutting down", d)
|
||
return
|
||
default:
|
||
}
|
||
|
||
conn, err := l.Accept()
|
||
if err != nil {
|
||
select {
|
||
case <-sm.ctx.Done():
|
||
return
|
||
default:
|
||
logrus.Errorf("Failed to accept TCP connection for %s: %v", d, err)
|
||
}
|
||
continue
|
||
}
|
||
|
||
sm.wg.Add(1)
|
||
go func(c net.Conn, domain string) {
|
||
defer sm.wg.Done()
|
||
defer c.Close()
|
||
sm.handleTCPConnection(c, domain)
|
||
}(conn, d)
|
||
}
|
||
}(domain, listener)
|
||
port++
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// startUDP запускает UDP сервер на общем порту 10000
|
||
// и отдельные порты для каждого домена (11001, 11002, ...)
|
||
func (sm *StreamManager) startUDP() error {
|
||
// Общий сервер на порту 10000 (требует указания домена в заголовке)
|
||
addr, err := net.ResolveUDPAddr("udp", ":10000")
|
||
if err != nil {
|
||
return fmt.Errorf("failed to resolve UDP address: %w", err)
|
||
}
|
||
|
||
conn, err := net.ListenUDP("udp", addr)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to listen on UDP port 10000: %w", err)
|
||
}
|
||
|
||
sm.udpConn = conn
|
||
logrus.Info("UDP stream server started on port 10000 (shared mode)")
|
||
|
||
sm.wg.Add(1)
|
||
go func() {
|
||
defer sm.wg.Done()
|
||
buf := make([]byte, 65536)
|
||
for {
|
||
select {
|
||
case <-sm.ctx.Done():
|
||
logrus.Info("UDP stream server shutting down")
|
||
return
|
||
default:
|
||
}
|
||
|
||
n, addr, err := conn.ReadFromUDP(buf)
|
||
if err != nil {
|
||
select {
|
||
case <-sm.ctx.Done():
|
||
return
|
||
default:
|
||
logrus.Errorf("Failed to read UDP packet: %v", err)
|
||
}
|
||
continue
|
||
}
|
||
|
||
sm.wg.Add(1)
|
||
go func(data []byte, a *net.UDPAddr) {
|
||
defer sm.wg.Done()
|
||
sm.handleUDPConnection(data, a, "")
|
||
}(buf[:n], addr)
|
||
}
|
||
}()
|
||
|
||
// Отдельные порты для каждого домена (прозрачный режим)
|
||
port := 11001
|
||
for domain, cfg := range sm.streams {
|
||
if cfg.Protocol != "udp" {
|
||
continue
|
||
}
|
||
addr := fmt.Sprintf(":%d", port)
|
||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||
if err != nil {
|
||
logrus.Warnf("Failed to resolve UDP address %s for domain %s: %v", addr, domain, err)
|
||
port++
|
||
continue
|
||
}
|
||
|
||
conn, err := net.ListenUDP("udp", udpAddr)
|
||
if err != nil {
|
||
logrus.Warnf("Failed to listen on UDP port %d for domain %s: %v", port, domain, err)
|
||
port++
|
||
continue
|
||
}
|
||
|
||
sm.udpConns[domain] = conn
|
||
logrus.Infof("UDP stream server for %s started on port %d (transparent mode)", domain, port)
|
||
|
||
sm.wg.Add(1)
|
||
go func(d string, c *net.UDPConn) {
|
||
defer sm.wg.Done()
|
||
buf := make([]byte, 65536)
|
||
for {
|
||
select {
|
||
case <-sm.ctx.Done():
|
||
logrus.Infof("UDP stream server for %s shutting down", d)
|
||
return
|
||
default:
|
||
}
|
||
|
||
n, addr, err := c.ReadFromUDP(buf)
|
||
if err != nil {
|
||
select {
|
||
case <-sm.ctx.Done():
|
||
return
|
||
default:
|
||
logrus.Errorf("Failed to read UDP packet for %s: %v", d, err)
|
||
}
|
||
continue
|
||
}
|
||
|
||
sm.wg.Add(1)
|
||
go func(data []byte, a *net.UDPAddr, domain string) {
|
||
defer sm.wg.Done()
|
||
sm.handleUDPConnection(data, a, domain)
|
||
}(buf[:n], addr, d)
|
||
}
|
||
}(domain, conn)
|
||
port++
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// handleTCPConnection обрабатывает входящее TCP соединение
|
||
// Если domain указан, используется прозрачный режим
|
||
// Если domain пуст, требуется указание домена в заголовке
|
||
func (sm *StreamManager) handleTCPConnection(conn net.Conn, domain string) {
|
||
var targetDomain string
|
||
var initialData []byte
|
||
|
||
if domain != "" {
|
||
// Прозрачный режим - домен уже известен
|
||
targetDomain = domain
|
||
logrus.Infof("TCP connection from %s for domain: %s (transparent mode)", conn.RemoteAddr(), targetDomain)
|
||
} else {
|
||
// Общий режим - читаем заголовок для определения домена
|
||
headerBuf := make([]byte, 1024)
|
||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||
n, err := io.ReadFull(conn, headerBuf)
|
||
if err != nil {
|
||
logrus.Errorf("Failed to read TCP header: %v", err)
|
||
return
|
||
}
|
||
conn.SetReadDeadline(time.Time{})
|
||
|
||
targetDomain = extractDomainFromHeader(headerBuf[:n])
|
||
if targetDomain == "" {
|
||
logrus.Warn("No target domain found in TCP connection header")
|
||
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nDomain header required\n"))
|
||
return
|
||
}
|
||
initialData = headerBuf[:n]
|
||
logrus.Infof("TCP connection from %s targeting domain: %s", conn.RemoteAddr(), targetDomain)
|
||
}
|
||
|
||
// Находим конфигурацию стрима для этого домена
|
||
sm.mu.RLock()
|
||
streamConfig, exists := sm.streams[targetDomain]
|
||
sm.mu.RUnlock()
|
||
|
||
if !exists {
|
||
logrus.Warnf("No stream configuration found for domain: %s", targetDomain)
|
||
conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\nNo stream configured for this domain\n"))
|
||
return
|
||
}
|
||
|
||
if streamConfig.Protocol != "tcp" {
|
||
logrus.Warnf("Protocol mismatch: domain %s configured for %s, not TCP", targetDomain, streamConfig.Protocol)
|
||
conn.Write([]byte("HTTP/1.1 500 Internal Server Error\r\n\r\nProtocol mismatch\n"))
|
||
return
|
||
}
|
||
|
||
// Проксируем соединение
|
||
sm.proxyTCP(conn, streamConfig.Target, initialData)
|
||
}
|
||
|
||
// handleUDPConnection обрабатывает входящее UDP соединение
|
||
// Если domain указан, используется прозрачный режим
|
||
// Если domain пуст, требуется указание домена в заголовке
|
||
func (sm *StreamManager) handleUDPConnection(data []byte, addr *net.UDPAddr, domain string) {
|
||
var targetDomain string
|
||
|
||
if domain != "" {
|
||
// Прозрачный режим - домен уже известен
|
||
targetDomain = domain
|
||
logrus.Infof("UDP packet from %s for domain: %s (transparent mode)", addr, targetDomain)
|
||
} else {
|
||
// Общий режим - извлекаем домен из данных
|
||
targetDomain = extractDomainFromHeader(data)
|
||
if targetDomain == "" {
|
||
logrus.Warn("No target domain found in UDP packet")
|
||
return
|
||
}
|
||
logrus.Infof("UDP packet from %s targeting domain: %s", addr, targetDomain)
|
||
}
|
||
|
||
// Находим конфигурацию стрима для этого домена
|
||
sm.mu.RLock()
|
||
streamConfig, exists := sm.streams[targetDomain]
|
||
sm.mu.RUnlock()
|
||
|
||
if !exists {
|
||
logrus.Warnf("No stream configuration found for domain: %s", targetDomain)
|
||
return
|
||
}
|
||
|
||
if streamConfig.Protocol != "udp" {
|
||
logrus.Warnf("Protocol mismatch: domain %s configured for %s, not UDP", targetDomain, streamConfig.Protocol)
|
||
return
|
||
}
|
||
|
||
// Проксируем UDP пакет
|
||
sm.proxyUDP(data, addr, streamConfig.Target)
|
||
}
|
||
|
||
// proxyTCP проксирует TCP соединение
|
||
func (sm *StreamManager) proxyTCP(conn net.Conn, target string, initialData []byte) {
|
||
// Получаем таймаут из конфигурации (по умолчанию 300 секунд)
|
||
timeout := 300 * time.Second
|
||
if targetConfig, exists := sm.streams[conn.LocalAddr().String()]; exists {
|
||
if targetConfig.Timeout > 0 {
|
||
timeout = time.Duration(targetConfig.Timeout) * time.Second
|
||
}
|
||
}
|
||
|
||
// Разрешаем целевой адрес
|
||
targetAddr := target
|
||
if _, _, err := net.SplitHostPort(target); err != nil {
|
||
// Если порта нет, добавляем дефолтный
|
||
targetAddr = net.JoinHostPort(target, "80")
|
||
}
|
||
|
||
backendConn, err := net.DialTimeout("tcp", targetAddr, time.Duration(30)*time.Second)
|
||
if err != nil {
|
||
logrus.Errorf("Failed to connect to backend %s: %v", targetAddr, err)
|
||
conn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\nFailed to connect to backend\n"))
|
||
return
|
||
}
|
||
defer backendConn.Close()
|
||
|
||
// Устанавливаем таймаут
|
||
conn.SetDeadline(time.Now().Add(timeout))
|
||
backendConn.SetDeadline(time.Now().Add(timeout))
|
||
|
||
// Отправляем начальные данные на бэкенд
|
||
if len(initialData) > 0 {
|
||
if _, err := backendConn.Write(initialData); err != nil {
|
||
logrus.Errorf("Failed to write initial data to backend: %v", err)
|
||
return
|
||
}
|
||
}
|
||
|
||
// Запускаем двустороннюю передачу данных
|
||
var wg sync.WaitGroup
|
||
wg.Add(2)
|
||
|
||
go func() {
|
||
defer wg.Done()
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
logrus.Errorf("Panic in TCP forward (client->backend): %v", r)
|
||
}
|
||
}()
|
||
io.Copy(backendConn, conn)
|
||
}()
|
||
|
||
go func() {
|
||
defer wg.Done()
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
logrus.Errorf("Panic in TCP forward (backend->client): %v", r)
|
||
}
|
||
}()
|
||
io.Copy(conn, backendConn)
|
||
}()
|
||
|
||
wg.Wait()
|
||
logrus.Debugf("TCP proxy connection closed: %s -> %s", conn.RemoteAddr(), targetAddr)
|
||
}
|
||
|
||
// proxyUDP проксирует UDP пакет
|
||
func (sm *StreamManager) proxyUDP(data []byte, clientAddr *net.UDPAddr, target string) {
|
||
// Разрешаем целевой адрес
|
||
targetAddr := target
|
||
if _, _, err := net.SplitHostPort(target); err != nil {
|
||
targetAddr = net.JoinHostPort(target, "80")
|
||
}
|
||
|
||
backendAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
||
if err != nil {
|
||
logrus.Errorf("Failed to resolve backend address %s: %v", targetAddr, err)
|
||
return
|
||
}
|
||
|
||
backendConn, err := net.DialUDP("udp", nil, backendAddr)
|
||
if err != nil {
|
||
logrus.Errorf("Failed to connect to backend %s: %v", targetAddr, err)
|
||
return
|
||
}
|
||
defer backendConn.Close()
|
||
|
||
// Устанавливаем таймаут
|
||
backendConn.SetDeadline(time.Now().Add(30 * time.Second))
|
||
|
||
// Отправляем данные на бэкенд
|
||
if _, err := backendConn.Write(data); err != nil {
|
||
logrus.Errorf("Failed to send data to backend: %v", err)
|
||
return
|
||
}
|
||
|
||
// Читаем ответ от бэкенда
|
||
responseBuf := make([]byte, 65536)
|
||
n, err := backendConn.Read(responseBuf)
|
||
if err != nil {
|
||
logrus.Errorf("Failed to read response from backend: %v", err)
|
||
return
|
||
}
|
||
|
||
// Отправляем ответ клиенту
|
||
_, err = sm.udpConn.WriteToUDP(responseBuf[:n], clientAddr)
|
||
if err != nil {
|
||
logrus.Errorf("Failed to send response to client: %v", err)
|
||
return
|
||
}
|
||
|
||
logrus.Debugf("UDP proxy packet forwarded: %s -> %s -> %s", clientAddr, targetAddr, clientAddr)
|
||
}
|
||
|
||
// extractDomainFromHeader извлекает домен из заголовка
|
||
func extractDomainFromHeader(data []byte) string {
|
||
// Ищем первую строку, которая может содержать домен
|
||
// Поддерживаем форматы:
|
||
// - "DOMAIN example.com" (простой текстовый формат)
|
||
// - "CONNECT example.com:443 HTTP/1.1" (HTTP CONNECT)
|
||
// - "GET http://example.com/path HTTP/1.1" (HTTP GET)
|
||
|
||
lines := splitLines(data)
|
||
if len(lines) == 0 {
|
||
return ""
|
||
}
|
||
|
||
firstLine := string(lines[0])
|
||
|
||
// Проверяем формат CONNECT
|
||
if len(firstLine) > 8 && firstLine[:7] == "CONNECT" {
|
||
parts := splitString(firstLine[8:], ' ')
|
||
if len(parts) > 0 {
|
||
host := parts[0]
|
||
// Убираем порт если есть
|
||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||
return h
|
||
}
|
||
return host
|
||
}
|
||
}
|
||
|
||
// Проверяем формат GET/POST с полным URL
|
||
if (len(firstLine) > 3 && (firstLine[:3] == "GET" || firstLine[:3] == "POST")) {
|
||
parts := splitString(firstLine, ' ')
|
||
if len(parts) > 1 {
|
||
urlStr := parts[1]
|
||
// Пытаемся извлечь домен из URL
|
||
if urlStr[:7] == "http://" || urlStr[:8] == "https://" {
|
||
// Убираем схему
|
||
urlStr = urlStr[7:]
|
||
if urlStr[:8] == "https://" {
|
||
urlStr = urlStr[8:]
|
||
}
|
||
// Берем домен до первого слэша
|
||
for i, c := range urlStr {
|
||
if c == '/' || c == ':' {
|
||
urlStr = urlStr[:i]
|
||
break
|
||
}
|
||
}
|
||
return urlStr
|
||
}
|
||
}
|
||
}
|
||
|
||
// Проверяем простой текстовый формат "DOMAIN <domain>"
|
||
if len(firstLine) > 7 && firstLine[:6] == "DOMAIN" {
|
||
domain := trimString(firstLine[7:])
|
||
return domain
|
||
}
|
||
|
||
return ""
|
||
}
|
||
|
||
// splitLines разбивает данные на строки
|
||
func splitLines(data []byte) [][]byte {
|
||
var lines [][]byte
|
||
start := 0
|
||
for i, b := range data {
|
||
if b == '\n' || b == '\r' {
|
||
lines = append(lines, data[start:i])
|
||
start = i + 1
|
||
}
|
||
}
|
||
if start < len(data) {
|
||
lines = append(lines, data[start:])
|
||
}
|
||
return lines
|
||
}
|
||
|
||
// splitString разбивает строку по разделителю
|
||
func splitString(s string, sep rune) []string {
|
||
var parts []string
|
||
start := 0
|
||
for i, c := range s {
|
||
if c == sep {
|
||
parts = append(parts, s[start:i])
|
||
start = i + 1
|
||
}
|
||
}
|
||
parts = append(parts, s[start:])
|
||
return parts
|
||
}
|
||
|
||
// trimString убирает пробелы и переносы строк
|
||
func trimString(s string) string {
|
||
start := 0
|
||
end := len(s)
|
||
for start < end && (s[start] == ' ' || s[start] == '\t' || s[start] == '\n' || s[start] == '\r') {
|
||
start++
|
||
}
|
||
for end > start && (s[end-1] == ' ' || s[end-1] == '\t' || s[end-1] == '\n' || s[end-1] == '\r') {
|
||
end--
|
||
}
|
||
return s[start:end]
|
||
}
|
||
|
||
// Stop останавливает все серверы
|
||
func (sm *StreamManager) Stop() {
|
||
logrus.Info("Stopping stream manager...")
|
||
sm.cancel()
|
||
if sm.tcpListener != nil {
|
||
sm.tcpListener.Close()
|
||
}
|
||
if sm.udpConn != nil {
|
||
sm.udpConn.Close()
|
||
}
|
||
for _, listener := range sm.tcpListeners {
|
||
listener.Close()
|
||
}
|
||
for _, conn := range sm.udpConns {
|
||
conn.Close()
|
||
}
|
||
done := make(chan struct{})
|
||
go func() {
|
||
sm.wg.Wait()
|
||
close(done)
|
||
}()
|
||
select {
|
||
case <-done:
|
||
logrus.Info("Stream manager stopped")
|
||
case <-time.After(10 * time.Second):
|
||
logrus.Warn("Stream manager stop timeout")
|
||
}
|
||
}
|