diff --git a/config.go b/config.go index 3f188f0..57eef5f 100644 --- a/config.go +++ b/config.go @@ -1,12 +1,10 @@ -// TODO: Check permissions on file -// TODO: Create dir if needed -// TODO: Dont' fail hard on failed read package main import ( "bufio" "io/ioutil" "os" + "path" "github.com/naoina/toml" "github.com/op/go-logging" @@ -28,18 +26,25 @@ type FirewallConfigs struct { var FirewallConfig FirewallConfigs -func readConfig() { +func _readConfig(file string) []byte { f, err := os.Open(configDefaultPath) if err != nil { - log.Error(err.Error()) - os.Exit(1) + log.Warning(err.Error()) + return []byte{} } defer f.Close() buf, err := ioutil.ReadAll(f) if err != nil { - log.Error(err.Error()) - os.Exit(1) + log.Warning(err.Error()) + return []byte{} } + + return buf +} + +func readConfig() { + buf := _readConfig(configDefaultPath) + FirewallConfig = FirewallConfigs{ LogLevel: "NOTICE", LoggingLevel: logging.NOTICE, @@ -50,9 +55,11 @@ func readConfig() { DefaultActionId: 1, } - if err := toml.Unmarshal(buf, &FirewallConfig); err != nil { - log.Error(err.Error()) - os.Exit(1) + if len(buf) > 0 { + if err := toml.Unmarshal(buf, &FirewallConfig); err != nil { + log.Error(err.Error()) + os.Exit(1) + } } FirewallConfig.LoggingLevel, _ = logging.LogLevel(FirewallConfig.LogLevel) FirewallConfig.DefaultActionId = valueScope(FirewallConfig.DefaultAction) @@ -62,10 +69,19 @@ func writeConfig() { FirewallConfig.LogLevel = FirewallConfig.LoggingLevel.String() FirewallConfig.DefaultAction = printScope(FirewallConfig.DefaultActionId) + if _, err := os.Stat(path.Dir(configDefaultPath)); err != nil && os.IsNotExist(err) { + if err := os.MkdirAll(path.Dir(configDefaultPath), 0755); err != nil { + log.Error(err.Error()) + //os.Exit(1) + return + } + } + f, err := os.Create(configDefaultPath) if err != nil { log.Error(err.Error()) - os.Exit(1) + //os.Exit(1) + return } defer f.Close() @@ -73,7 +89,8 @@ func writeConfig() { cw := toml.NewEncoder(w) if err := cw.Encode(FirewallConfig); err != nil { log.Error(err.Error()) - os.Exit(1) + //os.Exit(1) + return } w.Flush() }