Files
alpine-router/firewall/firewall.go
2026-04-13 12:40:49 +03:00

202 lines
6.7 KiB
Go

package firewall
import (
	"fmt"
	"net"
	"os"
	"os/exec"
	"strings"
)
// tableName is the nftables table (family ip) that ApplyAll deletes and
// rebuilds on every call; both the forward and the postrouting chains live
// inside it.
const tableName = "alpine-router"
// Rule is a single stateless forward-filter rule.
//
// ApplyAll renders enabled rules, in slice order, into the forward chain of
// tableName: after the blocked-IP drops and before the LAN-isolation drop.
// An empty match field means "any".
type Rule struct {
ID string `yaml:"id" json:"id"` // NOTE(review): not read when rendering rules; presumably an API/UI handle — confirm
Enabled bool `yaml:"enabled" json:"enabled"` // disabled rules are not rendered at all
Action string `yaml:"action" json:"action"` // accept | drop | reject
Protocol string `yaml:"protocol" json:"protocol"` // tcp | udp | icmp | all
SrcAddr string `yaml:"src_addr" json:"src_addr"` // CIDR or IP, empty = any
SrcPort string `yaml:"src_port" json:"src_port"` // "80" | "80-443", empty = any
DstAddr string `yaml:"dst_addr" json:"dst_addr"` // CIDR or IP, empty = any
DstPort string `yaml:"dst_port" json:"dst_port"` // "80" | "80-443", empty = any
InIface string `yaml:"in_iface" json:"in_iface"` // input interface, empty = any
OutIface string `yaml:"out_iface" json:"out_iface"` // output interface, empty = any
Comment string `yaml:"comment" json:"comment"` // free text, emitted as a "#" line above the rule
}
// Config is the top-level firewall config stored in config.yaml.
type Config struct {
Rules []Rule `yaml:"rules" json:"rules"` // rendered in slice order; disabled entries are skipped
VLANIsolation bool `yaml:"vlan_isolation" json:"vlan_isolation"` // drop forwarding between LAN interfaces; only takes effect with >= 2 LAN interfaces
}
// NATConfig holds NAT masquerade settings (passed to avoid a direct nat import).
type NATConfig struct {
// Interfaces lists the interfaces whose incoming traffic ApplyAll accepts in
// the forward chain and masquerades in postrouting (both matched by iifname).
// An empty list means no postrouting chain is generated.
Interfaces []string
}
// IsInstalled reports whether an executable named "nft" can be located via
// exec.LookPath (i.e. on PATH).
func IsInstalled() bool {
	if _, lookErr := exec.LookPath("nft"); lookErr != nil {
		return false
	}
	return true
}
// ApplyAll atomically regenerates the complete nftables ruleset:
//   - NAT masquerade for natCfg.Interfaces
//   - Blocked client IP drops
//   - User rules from fwCfg (in order, enabled only)
//   - LAN isolation (if fwCfg.VLANIsolation): blocks traffic between any two LAN interfaces
//     (native + tagged VLANs). User rules placed above have priority.
//   - Default accept from LAN interfaces to WAN
//
// lanIfaces is the union of NAT interfaces and all VLAN interfaces — every interface
// that serves a local subnet. Isolation prevents any two of them from talking directly.
//
// Removing the previous tables (including the legacy "alpine-router-nat" table)
// and installing the new one happen in a single script. nft commits that script
// as one transaction. If nft rejects any part of it, the previously loaded
// ruleset stays in place instead of leaving forwarded traffic unfiltered.
// IPv4 forwarding is enabled only after the ruleset has been loaded. After a
// successful load, the conntrack table is flushed.
//
// Blocked IPs must be IPv4 addresses or IPv4 CIDRs. Interface names must be
// printable ASCII without spaces, quotes or backslashes. Both are written
// verbatim into the nft script, so anything else is rejected with an error
// before nft is invoked. Enabled user rules for which buildRuleLine returns ""
// are skipped.
func ApplyAll(natCfg NATConfig, fwCfg Config, blockedIPs, lanIfaces []string) error {
	// quoteIface renders an interface name as an nft quoted string, refusing
	// bytes that could terminate the string or split the statement.
	quoteIface := func(name string) (string, error) {
		for i := 0; i < len(name); i++ {
			if c := name[i]; c <= ' ' || c > '~' || c == '"' || c == '\\' {
				return "", fmt.Errorf("invalid interface name %q", name)
			}
		}
		return `"` + name + `"`, nil
	}
	// ipv4Match validates an IPv4 address or IPv4 CIDR (the table is family ip)
	// and returns its canonical text.
	ipv4Match := func(s string) (string, bool) {
		if ip := net.ParseIP(s).To4(); ip != nil {
			return ip.String(), true
		}
		if _, network, err := net.ParseCIDR(s); err == nil && len(network.Mask) == net.IPv4len {
			return network.String(), true
		}
		return "", false
	}
	runNft := func(script string) error {
		cmd := exec.Command("nft", "-f", "-")
		cmd.Stdin = strings.NewReader(script)
		if out, err := cmd.CombinedOutput(); err != nil {
			return fmt.Errorf("nft apply: %s: %w", strings.TrimSpace(string(out)), err)
		}
		return nil
	}
	enableForward := func() error {
		if err := os.WriteFile("/proc/sys/net/ipv4/ip_forward", []byte("1"), 0644); err != nil {
			return fmt.Errorf("enable ip_forward: %w", err)
		}
		return nil
	}

	var activeRules []Rule
	for _, r := range fwCfg.Rules {
		if r.Enabled {
			activeRules = append(activeRules, r)
		}
	}
	hasNAT := len(natCfg.Interfaces) > 0
	hasBlocked := len(blockedIPs) > 0
	hasVLANIsolation := fwCfg.VLANIsolation && len(lanIfaces) >= 2

	var sb strings.Builder
	// Remove both the legacy and the current table inside the same transaction
	// as the new definition. "add table" is a no-op for an existing table and
	// creates an empty one otherwise, so "delete table" cannot fail on a
	// missing table and abort the whole script.
	for _, name := range []string{"alpine-router-nat", tableName} {
		fmt.Fprintf(&sb, "add table ip %s\ndelete table ip %s\n", name, name)
	}

	if !hasNAT && !hasBlocked && !hasVLANIsolation && len(activeRules) == 0 {
		// Nothing to install: only remove what a previous run left behind.
		// Without an nft binary there can be no table to remove.
		if _, err := exec.LookPath("nft"); err == nil {
			if err := runNft(sb.String()); err != nil {
				return err
			}
		}
		return enableForward()
	}

	fmt.Fprintf(&sb, "table ip %s {\n", tableName)
	// ── Forward chain ────────────────────────────────────────────────────────
	sb.WriteString(" chain forward {\n")
	sb.WriteString(" type filter hook forward priority filter; policy drop;\n")
	sb.WriteString(" ct state established,related accept\n")
	for _, ip := range blockedIPs {
		addr, ok := ipv4Match(ip)
		if !ok {
			return fmt.Errorf("invalid blocked IP %q", ip)
		}
		fmt.Fprintf(&sb, " ip saddr %s drop\n", addr)
		fmt.Fprintf(&sb, " ip daddr %s drop\n", addr)
	}
	for _, rule := range activeRules {
		line := buildRuleLine(rule)
		if line == "" {
			continue
		}
		// Collapse all whitespace, newlines included, so a free-text comment
		// cannot end the "#" line and smuggle extra statements into the script.
		if comment := strings.Join(strings.Fields(rule.Comment), " "); comment != "" {
			fmt.Fprintf(&sb, " # %s\n", comment)
		}
		fmt.Fprintf(&sb, " %s\n", line)
	}
	// LAN isolation — drop traffic between any two local (LAN) interfaces.
	// Placed AFTER user rules so explicit allow rules above still take effect.
	if hasVLANIsolation {
		quoted := make([]string, len(lanIfaces))
		for i, name := range lanIfaces {
			q, err := quoteIface(name)
			if err != nil {
				return err
			}
			quoted[i] = q
		}
		set := "{ " + strings.Join(quoted, ", ") + " }"
		fmt.Fprintf(&sb, " iifname %s oifname %s drop\n", set, set)
	}
	natIfaces := make([]string, len(natCfg.Interfaces))
	for i, name := range natCfg.Interfaces {
		q, err := quoteIface(name)
		if err != nil {
			return err
		}
		natIfaces[i] = q
	}
	// Accept remaining traffic that arrives on a NAT interface; anything not
	// accepted by a rule above falls through to the chain's drop policy.
	for _, q := range natIfaces {
		fmt.Fprintf(&sb, " iifname %s accept\n", q)
	}
	sb.WriteString(" }\n")
	// ── Postrouting (masquerade) ─────────────────────────────────────────────
	if hasNAT {
		sb.WriteString(" chain postrouting {\n")
		sb.WriteString(" type nat hook postrouting priority srcnat; policy accept;\n")
		for _, q := range natIfaces {
			fmt.Fprintf(&sb, " iifname %s masquerade\n", q)
		}
		sb.WriteString(" }\n")
	}
	sb.WriteString("}\n")

	if err := runNft(sb.String()); err != nil {
		return err
	}
	// Flush connection tracking table so existing sessions are re-evaluated
	// against the new ruleset. Without this, traffic already tracked as
	// "established/related" bypasses new drop rules until the session ends.
	flushConntrack()
	// Enable forwarding only once the filter is in place, so a failed load
	// never leaves the box routing without a forward filter.
	return enableForward()
}
// flushConntrack empties the kernel connection-tracking table so that every
// flow is matched again against the freshly loaded ruleset. This includes flows
// already marked established/related. Without the flush, sessions opened before
// a new drop/reject rule keep passing via the "ct state established,related
// accept" rule.
//
// This is best effort and never reports failure. It first runs `conntrack -F`
// (conntrack-tools). Only if that command fails or is missing does it write
// "1" to a procfs file instead, discarding any error from that write.
func flushConntrack() {
	if toolErr := exec.Command("conntrack", "-F").Run(); toolErr != nil {
		// NOTE(review): confirm this procfs path exists on the target kernel;
		// if it does not, the write fails and the error is ignored.
		_ = os.WriteFile("/proc/sys/net/netfilter/nf_conntrack_flush", []byte("1"), 0644)
	}
}
// buildRuleLine converts a Rule to a single nftables match+action string, for
// example `iifname "lan0" ip saddr 10.0.0.0/8 tcp dport 22 accept`.
//
// The result is written verbatim into the script fed to `nft -f`, so every
// user-supplied field is validated first. Otherwise a value such as
// "1.2.3.4 accept", or one containing a newline, would inject arbitrary nft
// statements. Matches are emitted in a fixed order: input interface, output
// interface, source address, destination address, protocol/ports, then the
// action.
//
// Returns "" (the caller then skips the rule) if:
//   - Action is not accept, drop or reject (case-insensitive);
//   - Protocol is not tcp, udp, icmp, all or empty (case-insensitive);
//   - a port is set but Protocol is not tcp or udp — emitting the rule without
//     the port match would silently widen it;
//   - an address is not an IPv4 address or IPv4 CIDR (the table is family ip);
//   - a port is not "N" or "N-M" with N <= M <= 65535;
//   - an interface name contains spaces, quotes, backslashes, control or
//     non-ASCII bytes.
func buildRuleLine(r Rule) string {
	action := strings.ToLower(strings.TrimSpace(r.Action))
	switch action {
	case "accept", "drop", "reject":
	default:
		return ""
	}

	var parts []string
	for _, m := range []struct{ keyword, iface string }{
		{"iifname", r.InIface},
		{"oifname", r.OutIface},
	} {
		if m.iface == "" {
			continue
		}
		if !nftSafeIfaceName(m.iface) {
			return ""
		}
		parts = append(parts, fmt.Sprintf("%s %q", m.keyword, m.iface))
	}
	for _, m := range []struct{ keyword, addr string }{
		{"ip saddr", r.SrcAddr},
		{"ip daddr", r.DstAddr},
	} {
		if m.addr == "" {
			continue
		}
		addr, ok := nftIPv4Match(m.addr)
		if !ok {
			return ""
		}
		parts = append(parts, m.keyword+" "+addr)
	}

	proto := strings.ToLower(strings.TrimSpace(r.Protocol))
	hasPorts := r.SrcPort != "" || r.DstPort != ""
	switch proto {
	case "tcp", "udp":
		if !hasPorts {
			// A port match already implies the protocol, so name it
			// explicitly only when there is no port to match.
			parts = append(parts, "ip protocol "+proto)
		}
		for _, m := range []struct{ dir, port string }{
			{"sport", r.SrcPort},
			{"dport", r.DstPort},
		} {
			if m.port == "" {
				continue
			}
			port, ok := nftPortMatch(m.port)
			if !ok {
				return ""
			}
			parts = append(parts, proto+" "+m.dir+" "+port)
		}
	case "icmp", "all", "":
		if hasPorts {
			return ""
		}
		if proto == "icmp" {
			parts = append(parts, "ip protocol icmp")
		}
	default:
		return ""
	}

	parts = append(parts, action)
	return strings.Join(parts, " ")
}

// nftSafeIfaceName reports whether name can be placed inside a double-quoted
// nft string unchanged: printable ASCII only, with no space, '"' or '\'.
func nftSafeIfaceName(name string) bool {
	for i := 0; i < len(name); i++ {
		if c := name[i]; c <= ' ' || c > '~' || c == '"' || c == '\\' {
			return false
		}
	}
	return true
}

// nftIPv4Match validates s as an IPv4 address or IPv4 CIDR and returns its
// canonical text. A CIDR is reduced to its network address, so 10.0.0.5/8
// becomes 10.0.0.0/8.
func nftIPv4Match(s string) (string, bool) {
	s = strings.TrimSpace(s)
	if ip := net.ParseIP(s).To4(); ip != nil {
		return ip.String(), true
	}
	if _, network, err := net.ParseCIDR(s); err == nil && len(network.Mask) == net.IPv4len {
		return network.String(), true
	}
	return "", false
}

// nftPortMatch validates s as a single port "N" or an inclusive range "N-M"
// (N <= M <= 65535) and returns it in canonical decimal form. Only ASCII
// digits are accepted: no signs, spaces inside the value, or service names.
func nftPortMatch(s string) (string, bool) {
	var nums []int
	for _, bound := range strings.SplitN(strings.TrimSpace(s), "-", 2) {
		if bound == "" || len(bound) > 5 {
			return "", false
		}
		n := 0
		for i := 0; i < len(bound); i++ {
			if bound[i] < '0' || bound[i] > '9' {
				return "", false
			}
			n = n*10 + int(bound[i]-'0')
		}
		if n > 65535 {
			return "", false
		}
		nums = append(nums, n)
	}
	if len(nums) == 1 {
		return fmt.Sprint(nums[0]), true
	}
	if nums[0] > nums[1] {
		return "", false
	}
	return fmt.Sprintf("%d-%d", nums[0], nums[1]), true
}