commit a3b2c0f: [Minor] Fix stupid torch that uses `print` for logging

Vsevolod Stakhov vsevolod at highsecure.ru
Thu Mar 28 15:56:04 UTC 2019


Author: Vsevolod Stakhov
Date: 2019-03-28 15:53:17 +0000
URL: https://github.com/rspamd/rspamd/commit/a3b2c0f9db42a0b6d4d68d48654367e5b17b892a (HEAD -> master)

[Minor] Fix stupid torch that uses `print` for logging

---
 contrib/lua-torch/nn/StochasticGradient.lua | 11 +++++++----
 src/plugins/lua/neural.lua                  |  6 ++++--
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/contrib/lua-torch/nn/StochasticGradient.lua b/contrib/lua-torch/nn/StochasticGradient.lua
index a060371e8..dc80be1b1 100644
--- a/contrib/lua-torch/nn/StochasticGradient.lua
+++ b/contrib/lua-torch/nn/StochasticGradient.lua
@@ -8,6 +8,9 @@ function StochasticGradient:__init(module, criterion)
    self.module = module
    self.criterion = criterion
    self.verbose = true
+   self.logger = function(s)
+      print(s)
+   end
 end
 
 function StochasticGradient:train(dataset)
@@ -23,7 +26,7 @@ function StochasticGradient:train(dataset)
       end
    end
 
-   print("# StochasticGradient: training")
+   self.logger("# StochasticGradient: training")
 
    while true do
       local currentError = 0
@@ -49,13 +52,13 @@ function StochasticGradient:train(dataset)
       end
 
       if self.verbose then
-         print("# current error = " .. currentError)
+         self.logger("# current error = " .. currentError)
       end
       iteration = iteration + 1
       currentLearningRate = self.learningRate/(1+iteration*self.learningRateDecay)
       if self.maxIteration > 0 and iteration > self.maxIteration then
-         print("# StochasticGradient: you have reached the maximum number of iterations")
-         print("# training error = " .. currentError)
+         self.logger("# StochasticGradient: you have reached the maximum number of iterations")
+         self.logger("# training error = " .. currentError)
          break
       end
    end
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index b75adf468..30c4fee0f 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -671,11 +671,13 @@ local function train_ann(rule, _, ev_base, elt, worker)
             trainer.learning_rate = rule.train.learning_rate
             trainer.verbose = false
             trainer.maxIteration = rule.train.max_iterations
-            trainer.hookIteration = function(self, iteration, currentError)
+            trainer.hookIteration = function(_, iteration, currentError)
               rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
                   iteration, currentError)
             end
-
+            trainer.logger = function(s)
+              rspamd_logger.infox(rspamd_config, 'training: %s', s)
+            end
             trainer:train(dataset)
             local out = torch.MemoryFile()
             out:writeObject(rule.anns[elt].ann_train)


More information about the Commits mailing list