Skip to content

Commit

Permalink
Merge pull request #13 from remicres/model_fixes_2
Browse files Browse the repository at this point in the history
FIX: normalization of input for image summary
  • Loading branch information
remicres committed Mar 13, 2021
2 parents c4a6cbf + af1c902 commit 8c82b8c
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions code/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import partial
import otbtf
import logging
from ops import blur2d, downscale2d
from ops import downscale2d
from vgg import compute_vgg_loss
import network
import constants
Expand Down Expand Up @@ -89,24 +89,23 @@ def main(unused_argv):
iterator_init = iterator.make_initializer(tf_ds)
dataset_inputs = iterator.get_next()

# placeholders with normalization
def _get_normalized_input(key, scale, name):
# model inputs
def _get_input(key, name):
default_input = dataset_inputs[key]
shape = (None, None, None, ds.output_shapes[key][-1])
ph = tf.compat.v1.placeholder_with_default(default_input, shape=shape, name=name)
return scale * ph
return tf.compat.v1.placeholder_with_default(default_input, shape=shape, name=name)

lr_image = _get_normalized_input(constants.lr_key, params.lr_scale, constants.lr_input_name)
hr_image = _get_normalized_input(constants.hr_key, params.hr_scale, constants.hr_input_name)
lr_image = _get_input(constants.lr_key, constants.lr_input_name)
hr_image = _get_input(constants.hr_key, constants.hr_input_name)

# model
hr_nch = ds.output_shapes[constants.hr_key][-1]
generator = partial(network.generator, scope=constants.gen_scope, nchannels=hr_nch,
nresblocks=params.nresblocks, dim=params.depth)
discriminator = partial(network.discriminator, scope=constants.dis_scope, dim=params.depth)

hr_images_real = {factor: downscale2d(hr_image, factor=factor) for factor in constants.factors}
hr_images_fake = generator(lr_image)
hr_images_real = {factor: params.hr_scale * downscale2d(hr_image, factor=factor) for factor in constants.factors}
hr_images_fake = generator(params.lr_scale * lr_image)

# model outputs
gen = {factor: (1.0 / params.hr_scale) * hr_images_fake[factor] for factor in constants.factors}
Expand Down

0 comments on commit 8c82b8c

Please sign in to comment.