! Paste: delayed output problem

! Author: roma_marks
! Mode: factor
! Date: Thu, 7 Apr 2011 06:07:31
! Copyright (C) 2011 Roman Maksymchuk.
! See http://factorcode.org/license.txt for BSD license.
USING: accessors alien.c-types arrays colors.constants combinators combinators.smart fry google.charts grouping io kernel locals math math.functions math.parser math.vectors models namespaces prettyprint random sequences specialized-arrays strings threads tools.time ui ui.gadgets ui.gadgets.labels ui.gadgets.tracks ui.gadgets.worlds
prettyprint.config ;
IN: data-mining.neural-nets

FROM: alien.c-types => float ;
SPECIALIZED-ARRAY: float

FROM: alien.c-types => int ;
SPECIALIZED-ARRAY: int

! A fully connected feed-forward network, built by net-new.
!   #inputs         - length of the input vector (first entry of net-new's layer list)
!   #outputs        - length of the output vector (last entry of that list)
!   #layers         - number of weight layers (layer-list length - 1)
!   weights         - per weight layer, an array of rows: one row per neuron,
!                     each row holding one weight per input plus a trailing
!                     bias weight (net-prod appends 1.0 to every layer input)
!   act-funcs       - activation word for each weight layer
!   act-deriv-funcs - matching derivative word for each weight layer
!   outputs         - activated output vector of each layer, written by net-prod
TUPLE: net
   { #inputs integer }
   { #outputs integer }
   { #layers integer }
   { weights array }
   { act-funcs array }
   { act-deriv-funcs array }
   { outputs array } ;

! Dynamic variables shared by the words of this vocabulary.
SYMBOLS: act-mult dWeights err echo? graph? calc-err? mod-epoch mod-error ;

! Slope factor multiplied into the argument of every activation function.
2.0  act-mult set-global

! Feature switches for net-learn; store f in any of them to turn it off.
t  echo? set-global
t  graph? set-global
t  calc-err? set-global

! Constant bias input appended to every layer's input vector.
CONSTANT: el-1  { 1.0 }  inline

! Hyperbolic-tangent activation: tanh(a * x) with a = act-mult.
:: act-tanh  ( x -- y )
   x  act-mult get  *  tanh ; inline

! Derivative of act-tanh, expressed through its OUTPUT y = tanh(a * x):
!   d/dx tanh(a * x) = a * (1 - y^2),   a = act-mult
! The back-propagation words pass each layer's activated outputs (the values
! net-prod stores in the outputs slot). The argument is therefore y, and it
! must not be run through tanh a second time.
:: act-tanh-deriv  ( y -- dy )
   1  y  sq  -  act-mult get  * ; inline

! Linear activation: a * x with a = act-mult.
:: act-lin  ( x -- y )
   x  act-mult get  * ; inline

! Derivative of act-lin: the constant slope act-mult (the argument is ignored).
:: act-lin-deriv  ( x -- y )
   act-mult get ; inline

! Bipolar logistic activation with range (-1, 1):
!   2 / (1 + exp(-a * x)) - 1,   a = act-mult
:: act-logist  ( x -- y )
   2  x  act-mult get  *  neg  exp  1  +  /
   1  - ; inline

! Derivative of act-logist, expressed through its OUTPUT
!   y = 2 / (1 + exp(-a * x)) - 1 = tanh(a * x / 2),   a = act-mult
! so  dy/dx = a/2 * (1 - y^2).
! The back-propagation words pass each layer's activated outputs, so the
! argument is y. The unipolar-sigmoid form a*y*(1-y) does not apply here:
! y ranges over (-1, 1), and that form would go negative for y < 0.
:: act-logist-deriv  ( y -- dy )
   1  y  sq  -  act-mult get  *  2  / ; inline

! Element-wise squared difference (seq1 - seq2)^2.
: err-sq  ( seq1 seq2 -- seq )
   v-  [ sq ] map ; inline

! Element-wise signed difference seq1 - seq2.
: err-diff  ( seq1 seq2 -- seq )
   [ - ] 2map ; inline

! Element-wise absolute difference |seq1 - seq2|.
: err-abs  ( seq1 seq2 -- seq )
   [ -  abs ] 2map ; inline

! Build a net from LAYERS-SEQ (neuron count of every layer, inputs first) and
! ACT-FUNC-SEQ (one activation word per weight layer). An empty ACT-FUNC-SEQ
! means act-tanh everywhere. Any other word than act-tanh, act-lin or
! act-logist is printed and an error is thrown. Weight matrices are allocated
! as float-arrays (rows = neurons, columns = inputs + bias); net-init fills
! them with random values.
:: net-new  ( layers-seq act-func-seq -- net )
   act-func-seq  [ layers-seq  length  1 -  \ act-tanh  <array> ] when-empty  :> acts
   ! derivative word matching each activation word
   acts  [
      {
         { \ act-tanh   [ \ act-tanh-deriv ] }
         { \ act-lin    [ \ act-lin-deriv ] }
         { \ act-logist [ \ act-logist-deriv ] }
         [ .  "No such activation function!"  throw ]
      } case
   ] map  :> derivs
   net  new
      derivs  >>act-deriv-funcs
      acts  >>act-funcs
      layers-seq  first  >>#inputs
      layers-seq  last  >>#outputs
      layers-seq  length  1 -  >>#layers
      ! one { n-in n-out } pair per weight layer -> n-out rows of n-in + 1
      layers-seq  2 clump  [
         first2  swap  1 +  '[ _  <float-array> ] replicate
      ] map  >>weights ; inline

! Overwrite every weight of NET with a random value drawn from
! { -0.1, -0.0999, ..., 0.0999 } (random 2000, shifted and scaled).
: net-init  ( net -- net )
   [
      [ [ [ drop  2000  random  1000  -  10000  /f ] map ] map ] map
   ] change-weights ; inline

! Forward pass: feed INPUT through every layer of NET and store the activated
! output vector of each layer in NET's outputs slot. The constant el-1 (1.0)
! is appended to every layer input so the last weight of each row is a bias.
:: net-prod  ( input net -- )
   input  el-1  append
   net  weights>>  net  act-funcs>>  [| w act |
      ! one activated value per weight row (neuron), biased input kept below
      w  [ over  swap  v.  act  execute( x -- y ) ] map  nip
      ! leave the biased copy underneath as the next layer's input
      [ el-1  append ] keep
   ] 2map  nip
   net  outputs<< ; inline

! Online back-propagation: for every sample, run a forward pass, propagate the
! error returned by ERR-FUNC backwards and update NET's weights immediately.
!   param-seq : one { a b m } triple per weight layer (input layer first);
!               gradient step = a * b^epoch * (1 - m), and m scales the
!               previous update, which is kept in the dWeights variable
!   epoch     : exponent used in the step formula above
:: lern-backprop  ( input-seq output-seq err-func param-seq epoch net -- )
   output-seq  input-seq  [| inp |
      ! output-layer error: target <err-func> actual
      inp  net  [ net-prod ] [ outputs>>  last ] bi
      err-func  execute( x1 x2 -- y )
      ! From the output layer backwards: delta = error * derivative(output),
      ! and the error is passed back through the non-bias weights.
      net  [ outputs>>  reverse ]
      [ weights>>  reverse ]
      [ act-deriv-funcs>>  reverse ] tri  [| act  |
         [
            [ act  execute( x -- y )  * ] 2map  dup
         ] dip
         dup  first  length  1 -  0  <array>  [
            but-last-slice  n*v  v+
         ] 2reduce  swap
      ] 3map  nip
      ! Scaled gradient rows, computed in output-to-input order. param-seq is
      ! reversed as well, so each layer is paired with its own parameters.
      net  outputs>>  but-last-slice  inp  prefix  reverse  swap  param-seq  reverse  [
         first3  [ epoch  ^  *  1 ] dip  -  *  :> step
         [ el-1  append ] dip
         [ step  *  v*n ] with map
      ] 3map  reverse
      ! add the momentum term (back in input-to-output order)
      dWeights get  param-seq  [
         third  :> mom
         [ mom  v*n  v+ ] 2map
      ] 3map
      ! remember this update for the next sample and apply it
      dup  dWeights set
      net  weights>>  [ [ v+ ] 2map ] 2map
      net  weights<<
   ] 2each ; inline

! Batch back-propagation: the weight gradients of all samples are summed
! (not averaged) and NET's weights are updated once per call.
!   param-seq : one { a b m } triple per weight layer; the applied update is
!               step * summed-gradient + m * previous-update, with
!               step = a * b^epoch * (1 - m); the update is kept in dWeights
!   epoch     : exponent used in the step formula above
:: lern-backprop-batch  ( input-seq output-seq err-func param-seq epoch net -- )
   ! gradient accumulator: zero arrays shaped like dWeights
   dWeights get  [ [ length  0  <array> ] map ] map
   output-seq  input-seq  [| inp |
      ! output-layer error: target <err-func> actual
      inp  net  [ net-prod ] [ outputs>>  last ] bi
      err-func  execute( x1 x2 -- y )
      ! deltas from the output layer backwards; the error is passed back
      ! through the non-bias weights
      net  [ outputs>>  reverse ]
      [ weights>>  reverse ]
      [ act-deriv-funcs>>  reverse ] tri  [| act  |
         [
            [ act  execute( x -- y )  * ] 2map  dup
         ] dip
         dup  first  length  1 -  0  <array>  [
            but-last-slice  n*v  v+
         ] 2reduce  swap
      ] 3map  nip
      ! gradient rows: (layer input with 1.0 appended) * delta of each neuron
      net  outputs>>  but-last-slice  inp  prefix  reverse  swap  [
         [ el-1  append ] dip
         [ v*n ] with map
      ] 2map  reverse
      ! add this sample's gradients to the accumulator
      [ [ v+ ] 2map ] 2map
   ] 2each
   ! scale the summed gradients, add the momentum term, remember the update
   dWeights get  param-seq  [
      first3  :> mom  epoch  ^  *  1  mom  -  *  :> step
      [
         [ step  v*n ] [ mom  v*n ] bi*  v+
      ] 2map
   ] 3map  dup  dWeights set
   net  weights>>  [ [ v+ ] 2map ] 2map
   net  weights<< ; inline

! Root-mean-square error of NET on the given samples:
!   sqrt( sum of squared output errors / (#samples * #outputs) )
! The result is also stored in the err variable.
:: calc-rmse-error  ( input-seq output-seq net -- error )
   0
   output-seq  input-seq  [
      net  [ net-prod ] [ outputs>>  last ] bi  err-sq  sum  +
   ] 2each
   ! take the mean first, then the root (the mean belongs inside the sqrt)
   output-seq  length  /  net  #outputs>>  /  sqrt  dup  err set ; inline

! Misclassification rate of NET on the given samples. A sample counts as
! correct when the index of the largest target value equals the index of the
! largest network output (first index on ties).
! Side effect: err is set to a #outputs x #outputs array of int-arrays,
! a confusion matrix indexed [target class][predicted class].
:: calc-class-error  ( input-seq output-seq net -- error )
   net  #outputs>>  dup  '[ _  <int-array> ] replicate  err set
   0
   output-seq  input-seq  [
      ! target class, then predicted class
      net  [ net-prod  [ supremum ] [ index ] bi ]
      [ outputs>>  last  [ supremum ] [ index ] bi ] bi
      ! increment confusion-matrix cell [target][predicted]
      2dup  swap  err get  nth  [ nth  1 +  ] [ set-nth ] 2bi
      ! count a miss when the two classes differ
      =  not  [ 1 + ] when
   ] 2each
   output-seq  length  /f ; inline
   
! Train NET for EPOCHS epochs. Each epoch calls LERN-FUNC (lern-backprop or
! lern-backprop-batch) once on the training set; errors are measured with
! CALC-ERROR (calc-rmse-error or calc-class-error).
!   param-seq : one { a b m } triple per weight layer, or a single triple
!               that is replicated for every layer
!   test-input-seq / test-output-seq : may be empty (no test error then)
! Global switches:
!   echo?     - open a "Learning run" window with live epoch/error labels;
!               without calc-err? also print the last train/test error
!   calc-err? - evaluate CALC-ERROR after every epoch
!   graph?    - with calc-err?, print summary values and post a chart
:: net-learn  ( train-input-seq train-output-seq test-input-seq test-output-seq err-func lern-func param-seq calc-error epochs net -- )
   ! momentum terms start at zero, one per weight
   net  weights>>  [ [ length  0  <array> ] map ] map  dWeights set
   ! a single parameter triple is replicated for every layer
   param-seq  dup  length  1  =  [  net  #layers>>  swap  first  <array>  ] when
   :> params
   echo? get  [
      [
         vertical  <track>
         "Net: "  net  #inputs>>  number>string  append
         net  weights>>  [ length  neg  number>string  append ] each
         <label>  f  track-add
         "Activation functions: "  net  act-funcs>>  unparse  append
         <label>  f  track-add
         "Error function:       "  err-func  unparse  append
         <label>  f  track-add
         "Learning function:    "  lern-func  unparse  append
         <label>  f  track-add
         "Learning parameters:  "  params  unparse  append
         <label>  f  track-add
         "Error type:           "  calc-error  unparse  append
         <label>  f  track-add
         " "  <label>  f  track-add
         ! live labels, updated from the training loop below
         "Epoch: 0"  <model>  dup  mod-epoch set
         <label-control>  f  track-add
         "Error: -"  <model>  dup  mod-error set
         <label-control>  f  track-add
         world-attributes new
         "Learning run"  >>title
         open-window
      ] with-ui
   ] when
   calc-err? get  [
      ! one entry per epoch: the train error, or { train test } when a
      ! test set is given
      epochs  iota  [|  epoch  |
         train-input-seq  train-output-seq  err-func  params  epoch  1 +  net
         lern-func  execute( x1 x2 x3 x4 x5 x6 -- )
         train-input-seq  train-output-seq  net  calc-error  execute( x1 x2 x3 -- y )
         test-input-seq  [ test-output-seq  net  calc-error  execute( x1 x2 x3 -- y )  2array ] unless-empty
         echo? get  [
            "Epoch: "  epoch  1 +  number>string  append  mod-epoch get  set-model
            ! Show the error just computed: the test error if there is a test
            ! set, otherwise the train error. Unlike reading err (a number for
            ! calc-rmse-error, a confusion matrix for calc-class-error), this
            ! works with either error function.
            dup  test-input-seq  empty?  [ second ] unless
            number>string  "Error: "  prepend  mod-error get  set-model
            ! Factor threads are cooperative. Training in this thread never
            ! yields, so without this the UI thread only redraws the labels
            ! after learning has finished (the "delayed output" symptom).
            yield
         ] when
      ] map
      graph? get  [
         test-input-seq  empty?  [
            dup  supremum  :> ymax
            "Min. train error = "  write  dup  infimum  .
            "Last train error = "  write  dup  last  .
            err get  .
            <line> ! <sparkline>
            100  >>height
            500  >>width
            { "x" "y" }  >>axis
            { { 0  0  epochs } { 1  0  ymax } }  >>axis-range
            COLOR: white  >>background
            COLOR: DodgerBlue  >>foreground
            post-chart.
         ] [
            [ [ first ] map ] [ [ second ] map ] bi
            ! scale the y axis to the larger of the two curves
            2dup  [ supremum ] bi@  max  :> ymax
            [
               "Min. train error = "  write  dup  infimum  .
               "Last train error = "  write  dup  last  .
            ] dip
            "Min. test error = "  write  dup  infimum  .
            "Last test error = "  write  dup  last  .
            err get  .
            <2lines>   ! <sparkline>
            100  >>height
            500  >>width
            { "x" "y" }  >>axis
            { { 0  0  epochs } { 1  0  ymax } }  >>axis-range
            COLOR: white  >>background
            COLOR: DodgerBlue  COLOR: MediumSeaGreen  2array  >>foreground
            post-chart.
         ] if
      ] [ drop ] if
   ] [
      ! no per-epoch evaluation: just train, then report the final errors
      epochs  iota  [
         [ train-input-seq  train-output-seq  err-func  params ] dip  1 +  net
         lern-func  execute( x1 x2 x3 x4 x5 x6 -- )
      ] each
      train-input-seq  train-output-seq  net  calc-error  execute( x1 x2 x3 -- y )
      echo? get  [ "Last train error = "  write  . ] [ drop ] if
      test-input-seq  [
         test-output-seq  net  calc-error  execute( x1 x2 x3 -- y )
         echo? get  [ "Last test error = "  write  . ] [ drop ] if
      ] unless-empty
   ] if ; inline

! FOLDS-fold cross-validation. The samples are cut into FOLDS consecutive
! chunks of size = floor(#samples / FOLDS). The last fold's test chunk also
! takes the remainder, so it can be larger than size. For each fold, NET is
! trained on the other chunks via net-learn and tested on this one. Returns
! the mean of the per-fold test errors. For calc-class-error, each fold's
! error is 1 - trace(confusion matrix) / test-set size.
! NOTE(review): the same NET is trained fold after fold without resetting its
! weights, so later folds start from weights that have already seen their
! test data. Consider re-initialising (e.g. net-init) per fold.
:: (cross-valid)  ( folds input-seq output-seq err-func lern-func param-seq calc-error epochs net -- aver-error )
   output-seq  length  folds  /i  :> size
   folds  1 -  :> lst
   folds  iota  [| fold |
      ! leaves: train-in train-out test-in test-out
      fold  {
         { 0  [ input-seq  size  cut-slice  swap
            [ output-seq  size  cut-slice  swap ] dip  swap ] }
         { lst  [ input-seq  lst  size  *  cut-slice
            [ output-seq  lst  size  *  cut-slice ] dip  swap ] }
         [ [ input-seq  swap  size  *  cut-slice  size  cut-slice  swap
            [ append ] dip  output-seq ] keep
            size  *  cut-slice  size  cut-slice  swap  [ append  swap ] dip
         ]
      } case
      ! actual test-set size (larger than size for the last fold when the
      ! sample count is not divisible by folds)
      dup  length  :> test-size
      ! show the fold number right away instead of after training
      echo?  get  [ "Fold #"  write  fold  1 +  .  flush  yield ] when
      err-func  lern-func  param-seq  calc-error  epochs  net  net-learn  err get
      calc-error  \ calc-class-error  =  [
         net  #outputs>>  iota  swap  0  [ nth  + ] 2reduce  test-size  /f  1  swap  -
      ] when
   ] [ + ] map-reduce  folds  / ; inline

! Run (cross-valid) and print its average test error on one line.
: cross-valid  ( folds input-seq output-seq err-func lern-func param-seq calc-error epochs net -- )
   (cross-valid)  number>string
   "Average test error = "  prepend  print ; inline

! New Annotation

! Summary:
! Author:
! Mode:
! Body: