2026-06-04 17:32:11 +07:00

604 lines
17 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package streams
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
"acme-reverseproxy/config"
"github.com/sirupsen/logrus"
)
// StreamManager управляет TCP/UDP стримами
type StreamManager struct {
streams map[string]config.StreamConfig
tcpListener net.Listener
udpConn *net.UDPConn
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
// Поддержка отдельных портов для каждого домена
tcpListeners map[string]net.Listener
udpConns map[string]*net.UDPConn
}
// NewStreamManager создает новый менеджер стримов
func NewStreamManager(streams map[string]config.StreamConfig) *StreamManager {
ctx, cancel := context.WithCancel(context.Background())
return &StreamManager{
streams: streams,
ctx: ctx,
cancel: cancel,
tcpListeners: make(map[string]net.Listener),
udpConns: make(map[string]*net.UDPConn),
}
}
// Start запускает TCP и UDP серверы
func (sm *StreamManager) Start() error {
if len(sm.streams) == 0 {
logrus.Info("No streams configured, skipping stream manager initialization")
return nil
}
logrus.Infof("Starting stream manager with %d streams", len(sm.streams))
// Запускаем TCP сервер
if err := sm.startTCP(); err != nil {
return fmt.Errorf("failed to start TCP server: %w", err)
}
// Запускаем UDP сервер
if err := sm.startUDP(); err != nil {
sm.tcpListener.Close()
return fmt.Errorf("failed to start UDP server: %w", err)
}
return nil
}
// startTCP запускает TCP сервер на общем порту 10000
// и отдельные порты для каждого домена (10001, 10002, ...)
func (sm *StreamManager) startTCP() error {
// Общий сервер на порту 10000 (требует указания домена в заголовке)
listener, err := net.Listen("tcp", ":10000")
if err != nil {
return fmt.Errorf("failed to listen on port 10000: %w", err)
}
sm.tcpListener = listener
logrus.Info("TCP stream server started on port 10000 (shared mode)")
sm.wg.Add(1)
go func() {
defer sm.wg.Done()
for {
select {
case <-sm.ctx.Done():
logrus.Info("TCP stream server shutting down")
return
default:
}
conn, err := listener.Accept()
if err != nil {
select {
case <-sm.ctx.Done():
return
default:
logrus.Errorf("Failed to accept TCP connection: %v", err)
}
continue
}
sm.wg.Add(1)
go func(c net.Conn) {
defer sm.wg.Done()
defer c.Close()
sm.handleTCPConnection(c, "")
}(conn)
}
}()
// Отдельные порты для каждого домена (прозрачный режим)
port := 10001
for domain, cfg := range sm.streams {
if cfg.Protocol != "tcp" {
continue
}
addr := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", addr)
if err != nil {
logrus.Warnf("Failed to listen on port %d for domain %s: %v", port, domain, err)
port++
continue
}
sm.tcpListeners[domain] = listener
logrus.Infof("TCP stream server for %s started on port %d (transparent mode)", domain, port)
sm.wg.Add(1)
go func(d string, l net.Listener) {
defer sm.wg.Done()
for {
select {
case <-sm.ctx.Done():
logrus.Infof("TCP stream server for %s shutting down", d)
return
default:
}
conn, err := l.Accept()
if err != nil {
select {
case <-sm.ctx.Done():
return
default:
logrus.Errorf("Failed to accept TCP connection for %s: %v", d, err)
}
continue
}
sm.wg.Add(1)
go func(c net.Conn, domain string) {
defer sm.wg.Done()
defer c.Close()
sm.handleTCPConnection(c, domain)
}(conn, d)
}
}(domain, listener)
port++
}
return nil
}
// startUDP запускает UDP сервер на общем порту 10000
// и отдельные порты для каждого домена (11001, 11002, ...)
func (sm *StreamManager) startUDP() error {
// Общий сервер на порту 10000 (требует указания домена в заголовке)
addr, err := net.ResolveUDPAddr("udp", ":10000")
if err != nil {
return fmt.Errorf("failed to resolve UDP address: %w", err)
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return fmt.Errorf("failed to listen on UDP port 10000: %w", err)
}
sm.udpConn = conn
logrus.Info("UDP stream server started on port 10000 (shared mode)")
sm.wg.Add(1)
go func() {
defer sm.wg.Done()
buf := make([]byte, 65536)
for {
select {
case <-sm.ctx.Done():
logrus.Info("UDP stream server shutting down")
return
default:
}
n, addr, err := conn.ReadFromUDP(buf)
if err != nil {
select {
case <-sm.ctx.Done():
return
default:
logrus.Errorf("Failed to read UDP packet: %v", err)
}
continue
}
sm.wg.Add(1)
go func(data []byte, a *net.UDPAddr) {
defer sm.wg.Done()
sm.handleUDPConnection(data, a, "")
}(buf[:n], addr)
}
}()
// Отдельные порты для каждого домена (прозрачный режим)
port := 11001
for domain, cfg := range sm.streams {
if cfg.Protocol != "udp" {
continue
}
addr := fmt.Sprintf(":%d", port)
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
logrus.Warnf("Failed to resolve UDP address %s for domain %s: %v", addr, domain, err)
port++
continue
}
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
logrus.Warnf("Failed to listen on UDP port %d for domain %s: %v", port, domain, err)
port++
continue
}
sm.udpConns[domain] = conn
logrus.Infof("UDP stream server for %s started on port %d (transparent mode)", domain, port)
sm.wg.Add(1)
go func(d string, c *net.UDPConn) {
defer sm.wg.Done()
buf := make([]byte, 65536)
for {
select {
case <-sm.ctx.Done():
logrus.Infof("UDP stream server for %s shutting down", d)
return
default:
}
n, addr, err := c.ReadFromUDP(buf)
if err != nil {
select {
case <-sm.ctx.Done():
return
default:
logrus.Errorf("Failed to read UDP packet for %s: %v", d, err)
}
continue
}
sm.wg.Add(1)
go func(data []byte, a *net.UDPAddr, domain string) {
defer sm.wg.Done()
sm.handleUDPConnection(data, a, domain)
}(buf[:n], addr, d)
}
}(domain, conn)
port++
}
return nil
}
// handleTCPConnection обрабатывает входящее TCP соединение
// Если domain указан, используется прозрачный режим
// Если domain пуст, требуется указание домена в заголовке
func (sm *StreamManager) handleTCPConnection(conn net.Conn, domain string) {
var targetDomain string
var initialData []byte
if domain != "" {
// Прозрачный режим - домен уже известен
targetDomain = domain
logrus.Infof("TCP connection from %s for domain: %s (transparent mode)", conn.RemoteAddr(), targetDomain)
} else {
// Общий режим - читаем заголовок для определения домена
headerBuf := make([]byte, 1024)
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
n, err := io.ReadFull(conn, headerBuf)
if err != nil {
logrus.Errorf("Failed to read TCP header: %v", err)
return
}
conn.SetReadDeadline(time.Time{})
targetDomain = extractDomainFromHeader(headerBuf[:n])
if targetDomain == "" {
logrus.Warn("No target domain found in TCP connection header")
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nDomain header required\n"))
return
}
initialData = headerBuf[:n]
logrus.Infof("TCP connection from %s targeting domain: %s", conn.RemoteAddr(), targetDomain)
}
// Находим конфигурацию стрима для этого домена
sm.mu.RLock()
streamConfig, exists := sm.streams[targetDomain]
sm.mu.RUnlock()
if !exists {
logrus.Warnf("No stream configuration found for domain: %s", targetDomain)
conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\nNo stream configured for this domain\n"))
return
}
if streamConfig.Protocol != "tcp" {
logrus.Warnf("Protocol mismatch: domain %s configured for %s, not TCP", targetDomain, streamConfig.Protocol)
conn.Write([]byte("HTTP/1.1 500 Internal Server Error\r\n\r\nProtocol mismatch\n"))
return
}
// Проксируем соединение
sm.proxyTCP(conn, streamConfig.Target, initialData)
}
// handleUDPConnection обрабатывает входящее UDP соединение
// Если domain указан, используется прозрачный режим
// Если domain пуст, требуется указание домена в заголовке
func (sm *StreamManager) handleUDPConnection(data []byte, addr *net.UDPAddr, domain string) {
var targetDomain string
if domain != "" {
// Прозрачный режим - домен уже известен
targetDomain = domain
logrus.Infof("UDP packet from %s for domain: %s (transparent mode)", addr, targetDomain)
} else {
// Общий режим - извлекаем домен из данных
targetDomain = extractDomainFromHeader(data)
if targetDomain == "" {
logrus.Warn("No target domain found in UDP packet")
return
}
logrus.Infof("UDP packet from %s targeting domain: %s", addr, targetDomain)
}
// Находим конфигурацию стрима для этого домена
sm.mu.RLock()
streamConfig, exists := sm.streams[targetDomain]
sm.mu.RUnlock()
if !exists {
logrus.Warnf("No stream configuration found for domain: %s", targetDomain)
return
}
if streamConfig.Protocol != "udp" {
logrus.Warnf("Protocol mismatch: domain %s configured for %s, not UDP", targetDomain, streamConfig.Protocol)
return
}
// Проксируем UDP пакет
sm.proxyUDP(data, addr, streamConfig.Target)
}
// proxyTCP проксирует TCP соединение
func (sm *StreamManager) proxyTCP(conn net.Conn, target string, initialData []byte) {
// Получаем таймаут из конфигурации (по умолчанию 300 секунд)
timeout := 300 * time.Second
if targetConfig, exists := sm.streams[conn.LocalAddr().String()]; exists {
if targetConfig.Timeout > 0 {
timeout = time.Duration(targetConfig.Timeout) * time.Second
}
}
// Разрешаем целевой адрес
targetAddr := target
if _, _, err := net.SplitHostPort(target); err != nil {
// Если порта нет, добавляем дефолтный
targetAddr = net.JoinHostPort(target, "80")
}
backendConn, err := net.DialTimeout("tcp", targetAddr, time.Duration(30)*time.Second)
if err != nil {
logrus.Errorf("Failed to connect to backend %s: %v", targetAddr, err)
conn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\nFailed to connect to backend\n"))
return
}
defer backendConn.Close()
// Устанавливаем таймаут
conn.SetDeadline(time.Now().Add(timeout))
backendConn.SetDeadline(time.Now().Add(timeout))
// Отправляем начальные данные на бэкенд
if len(initialData) > 0 {
if _, err := backendConn.Write(initialData); err != nil {
logrus.Errorf("Failed to write initial data to backend: %v", err)
return
}
}
// Запускаем двустороннюю передачу данных
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
logrus.Errorf("Panic in TCP forward (client->backend): %v", r)
}
}()
io.Copy(backendConn, conn)
}()
go func() {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
logrus.Errorf("Panic in TCP forward (backend->client): %v", r)
}
}()
io.Copy(conn, backendConn)
}()
wg.Wait()
logrus.Debugf("TCP proxy connection closed: %s -> %s", conn.RemoteAddr(), targetAddr)
}
// proxyUDP проксирует UDP пакет
func (sm *StreamManager) proxyUDP(data []byte, clientAddr *net.UDPAddr, target string) {
// Разрешаем целевой адрес
targetAddr := target
if _, _, err := net.SplitHostPort(target); err != nil {
targetAddr = net.JoinHostPort(target, "80")
}
backendAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil {
logrus.Errorf("Failed to resolve backend address %s: %v", targetAddr, err)
return
}
backendConn, err := net.DialUDP("udp", nil, backendAddr)
if err != nil {
logrus.Errorf("Failed to connect to backend %s: %v", targetAddr, err)
return
}
defer backendConn.Close()
// Устанавливаем таймаут
backendConn.SetDeadline(time.Now().Add(30 * time.Second))
// Отправляем данные на бэкенд
if _, err := backendConn.Write(data); err != nil {
logrus.Errorf("Failed to send data to backend: %v", err)
return
}
// Читаем ответ от бэкенда
responseBuf := make([]byte, 65536)
n, err := backendConn.Read(responseBuf)
if err != nil {
logrus.Errorf("Failed to read response from backend: %v", err)
return
}
// Отправляем ответ клиенту
_, err = sm.udpConn.WriteToUDP(responseBuf[:n], clientAddr)
if err != nil {
logrus.Errorf("Failed to send response to client: %v", err)
return
}
logrus.Debugf("UDP proxy packet forwarded: %s -> %s -> %s", clientAddr, targetAddr, clientAddr)
}
// extractDomainFromHeader извлекает домен из заголовка
func extractDomainFromHeader(data []byte) string {
// Ищем первую строку, которая может содержать домен
// Поддерживаем форматы:
// - "DOMAIN example.com" (простой текстовый формат)
// - "CONNECT example.com:443 HTTP/1.1" (HTTP CONNECT)
// - "GET http://example.com/path HTTP/1.1" (HTTP GET)
lines := splitLines(data)
if len(lines) == 0 {
return ""
}
firstLine := string(lines[0])
// Проверяем формат CONNECT
if len(firstLine) > 8 && firstLine[:7] == "CONNECT" {
parts := splitString(firstLine[8:], ' ')
if len(parts) > 0 {
host := parts[0]
// Убираем порт если есть
if h, _, err := net.SplitHostPort(host); err == nil {
return h
}
return host
}
}
// Проверяем формат GET/POST с полным URL
if (len(firstLine) > 3 && (firstLine[:3] == "GET" || firstLine[:3] == "POST")) {
parts := splitString(firstLine, ' ')
if len(parts) > 1 {
urlStr := parts[1]
// Пытаемся извлечь домен из URL
if urlStr[:7] == "http://" || urlStr[:8] == "https://" {
// Убираем схему
urlStr = urlStr[7:]
if urlStr[:8] == "https://" {
urlStr = urlStr[8:]
}
// Берем домен до первого слэша
for i, c := range urlStr {
if c == '/' || c == ':' {
urlStr = urlStr[:i]
break
}
}
return urlStr
}
}
}
// Проверяем простой текстовый формат "DOMAIN <domain>"
if len(firstLine) > 7 && firstLine[:6] == "DOMAIN" {
domain := trimString(firstLine[7:])
return domain
}
return ""
}
// splitLines разбивает данные на строки
func splitLines(data []byte) [][]byte {
var lines [][]byte
start := 0
for i, b := range data {
if b == '\n' || b == '\r' {
lines = append(lines, data[start:i])
start = i + 1
}
}
if start < len(data) {
lines = append(lines, data[start:])
}
return lines
}
// splitString разбивает строку по разделителю
func splitString(s string, sep rune) []string {
var parts []string
start := 0
for i, c := range s {
if c == sep {
parts = append(parts, s[start:i])
start = i + 1
}
}
parts = append(parts, s[start:])
return parts
}
// trimString убирает пробелы и переносы строк
func trimString(s string) string {
start := 0
end := len(s)
for start < end && (s[start] == ' ' || s[start] == '\t' || s[start] == '\n' || s[start] == '\r') {
start++
}
for end > start && (s[end-1] == ' ' || s[end-1] == '\t' || s[end-1] == '\n' || s[end-1] == '\r') {
end--
}
return s[start:end]
}
// Stop останавливает все серверы
func (sm *StreamManager) Stop() {
logrus.Info("Stopping stream manager...")
sm.cancel()
if sm.tcpListener != nil {
sm.tcpListener.Close()
}
if sm.udpConn != nil {
sm.udpConn.Close()
}
for _, listener := range sm.tcpListeners {
listener.Close()
}
for _, conn := range sm.udpConns {
conn.Close()
}
done := make(chan struct{})
go func() {
sm.wg.Wait()
close(done)
}()
select {
case <-done:
logrus.Info("Stream manager stopped")
case <-time.After(10 * time.Second):
logrus.Warn("Stream manager stop timeout")
}
}