sequence batch & updated priors

parent 2841df2f
@@ -78,7 +78,7 @@ class ELBO(CriterionContainer):
         # decoder parameters
         prior = latent_params.get('prior') or None
         if prior is not None:
-            params2 = scale_prior(prior, out['z_enc'])(**kwargs)
+            params2 = scale_prior(prior, out['z_enc'])(batch_size = out1.shape, **kwargs)
             out2 = params2.rsample()
         elif out.get('z_params_dec') is not None:
             params2 = out['z_params_dec']
......
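Note on the ELBO hunk above: the decoder-side prior is now instantiated with an explicit `batch_size` (taken from `out1`, which lives in the surrounding ELBO code and is not shown in this hunk), so the prior distribution is sized to the current latent batch. A minimal sketch of that contract, assuming the latent is a `(batch, latent_dim)` tensor:

```python
import torch
from torch.distributions import Normal

z_enc = torch.randn(64, 16)        # stand-in for the encoder latent out['z_enc']
# IsotropicGaussian(batch_size=z_enc.shape) (see the priors hunk below) builds exactly this:
prior = Normal(torch.zeros(z_enc.shape), torch.ones(z_enc.shape))
out2 = prior.rsample()             # prior sample matching the latent's shape
assert out2.shape == z_enc.shape
```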
@@ -24,7 +24,7 @@ class ShapeError(Exception):
 class Selector(object):
     """takes everything"""
-    def __init__(self):
+    def __init__(self, **kwargs):
         pass
     def __repr__(self):
@@ -55,7 +55,7 @@ class SelectorChain(Selector):
 class IndexPick(Selector):
-    def __init__(self, idx=None, axis=0):
+    def __init__(self, idx=None, axis=0, **kwargs):
         assert idx is not None
         self.idx = idx; self.axis = axis
@@ -74,12 +74,12 @@ class IndexPick(Selector):
         offset = (idx * shape[1])*strides[1]
         return np.memmap(file, dtype=dtype, mode='r', offset=offset, shape=self.get_shape(shape))

 class SequencePick(Selector):
-    def __init__(self, sequence_length=None, random_idx=True):
+    def __init__(self, sequence_length=None, random_idx=True, offset=0, **kwargs):
         assert sequence_length
         self.sequence_length = sequence_length
         self.random_idx = random_idx
+        self.offset = offset
     def __repr__(self):
         return "SequencePick(%d, random_idx:%s)"%(self.sequence_length, self.random_idx)
@@ -95,7 +95,9 @@ class SequencePick(Selector):
         idx = 0
         if self.random_idx:
             if not shape[0] - self.sequence_length <= 0:
-                dx = np.random.randint(shape[0] - self.sequence_length)
+                idx = np.random.randint(shape[0] - self.sequence_length)
+        else:
+            idx = self.offset
         offset = (idx * shape[1])*strides[1]
         load_shape = shape if shape[0] < self.sequence_length else self.get_shape(shape)
         data = np.memmap(file, dtype=dtype, mode='r', offset=offset, shape=load_shape)
@@ -105,8 +107,58 @@ class SequencePick(Selector):
             data = np.pad(np.array(data), pad_width=pads, mode="constant", constant_values=0)
         return data

 class UnwrapBatchPick(SequencePick):
     def get_shape(self, shape):
         return shape[1:]
+
+class SequenceBatchPick(Selector):
+    def __init__(self, sequence_length=None, batches=32, overlap=None, random_idx=True, **kwargs):
+        assert sequence_length
+        self.sequence_length = sequence_length
+        self.random_idx = random_idx
+        self.batches = batches
+        self.overlap = overlap or sequence_length//2
+
+    def __repr__(self):
+        if self.random_idx:
+            return "SequenceBatchPick(%d, random_idx:%s, batches:%s)"%(self.sequence_length, self.random_idx,
+                                                                       self.batches)
+        else:
+            return "SequenceBatchPick(%d, random_idx:%s, overlap:%s)"%(self.sequence_length, self.random_idx,
+                                                                       self.overlap)
+
+    def __getitem__(self, i):
+        if self.random_idx:
+            return SequencePick(self.sequence_length, random_idx=True)
+        else:
+            return SequencePick(self.sequence_length, random_idx=False, offset=(i*self.overlap))
+
+    def get_shape(self, shape):
+        return (self.batches, self.sequence_length, *shape[1:])
+
+    def __call__(self, file, shape=None, axis=0, dtype=np.float, idx=None, strides=None, **kwargs):
+        assert axis==0, "memmap indexing on axis > 0 is not implemented yet"
+        if self.random_idx:
+            indices = np.array([np.random.randint(shape[0] - self.sequence_length) for _ in range(self.batches)])
+        else:
+            indices = np.array([i*self.overlap for i in range(self.batches)])
+        items = np.zeros(self.get_shape(shape), dtype=dtype)
+        for i, idx in enumerate(indices):
+            offset = (idx * shape[1])*strides[1]
+            load_shape = shape if shape[0] < self.sequence_length else self.get_shape(shape)[1:]
+            data = np.memmap(file, dtype=dtype, mode='r', offset=offset, shape=load_shape)
+            if data.shape[0] < self.sequence_length:
+                pads = [(0,0)]*len(shape)
+                pads[0] = (0, self.sequence_length - data.shape[0])
+                data = np.pad(np.array(data), pad_width=pads, mode="constant", constant_values=0)
+            items[i] = data
+        return items
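A rough usage sketch for the new `SequenceBatchPick`, assuming a C-contiguous 2-D float32 array dumped raw to disk (the file name, array sizes, and the import of the selector class are placeholders for illustration):

```python
import numpy as np

arr = np.random.randn(200, 8).astype(np.float32)
arr.tofile("/tmp/dummy.dat")        # raw row-major dump, as the byte-offset arithmetic expects

picker = SequenceBatchPick(sequence_length=16, batches=32, random_idx=True)
batch = picker("/tmp/dummy.dat", shape=arr.shape, dtype=np.float32, strides=arr.strides)
print(batch.shape)                  # (32, 16, 8) = (batches, sequence_length, features)
```

With `random_idx=False`, the windows instead start at `i * overlap`, so consecutive items share `sequence_length - overlap` frames (overlap defaults to half the sequence length).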
 class RandomPick(Selector):
-    def __init__(self, range=None, axis=0):
+    def __init__(self, range=None, axis=0, **kwargs):
         self.range = range
         if self.range:
             self.range = tuple(self.range)
@@ -125,8 +177,9 @@ class RandomPick(Selector):
         offset = (idx * shape[1])*strides[1]
         return np.memmap(file, dtype=dtype, mode='r', offset=offset, shape=self.get_shape(shape))
 class RandomRangePick(Selector):
-    def __init__(self, range=None, length=64, axis=0):
+    def __init__(self, range=None, length=64, axis=0, **kwargs):
         self.range = range; self.length = length
         if self.range:
             self.range = tuple(self.range)
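One practical effect of threading `**kwargs` through every selector constructor (`Selector`, `IndexPick`, `SequencePick`, `SequenceBatchPick`, `RandomPick`, `RandomRangePick` above) is that a single keyword configuration can be handed to any of them, each class keeping the arguments it needs and ignoring the rest; a hypothetical illustration:

```python
# hypothetical shared config; keys a given selector does not use are absorbed by **kwargs
config = dict(sequence_length=16, batches=8, random_idx=False, idx=3, axis=0)

selectors = [IndexPick(**config), SequencePick(**config), SequenceBatchPick(**config)]
```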
@@ -250,7 +303,8 @@ class OfflineEntry(object):
         else:
             entries = [None]*self.shape[axis]
             for i in range(self.shape[axis]):
-                entries[i] = type(self)(self.file, dtype=self.dtype, shape=self.shape, strides=self.strides, selector=self.selector[i])
+                #entries[i] = type(self)(self.file, dtype=self.dtype, shape=self.shape, strides=self.strides, selector=self.selector[i])
+                entries[i] = type(self)(self.file, dtype=self.dtype, shape=self._pre_shape, strides=self.strides, selector=self.selector[i])
             return entries
......
@@ -6,15 +6,18 @@ from . import Distribution
 from .distribution_flow import Flow

-def IsotropicGaussian(batch_size, device="cpu", requires_grad=False):
+def IsotropicGaussian(batch_size=None, device="cpu", requires_grad=False, **kwargs):
+    assert batch_size
     return Normal(zeros(*batch_size, device=device, requires_grad=requires_grad),
                   ones(*batch_size, device=device, requires_grad=requires_grad))

-def WienerProcess(batch_size, device="cpu", requires_grad=False):
+def WienerProcess(batch_size=None, device="cpu", requires_grad=False, **kwargs):
+    assert batch_size
     return RandomWalk(zeros(*batch_size, device=device, requires_grad=requires_grad),
                       ones(*batch_size, device=device, requires_grad=requires_grad))

-def IsotropicMultivariateGaussian(batch_size, device="cpu", requires_grad=False):
+def IsotropicMultivariateGaussian(batch_size=None, device="cpu", requires_grad=False, **kwargs):
+    assert batch_size
     return MultivariateNormal(zeros(*batch_size, device=device, requires_grad=requires_grad),
                               covariance_matrix=eye(*batch_size, device=device, requires_grad=requires_grad))
......
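With `batch_size` now a keyword guarded by an assert, the prior factories can be driven uniformly through keyword arguments (which is what the ELBO hunk above relies on) and fail fast when the size is missing. A short usage sketch, assuming the factories are imported from the patched priors module:

```python
# assumes IsotropicGaussian / WienerProcess are imported from the patched priors module
prior = IsotropicGaussian(batch_size=(64, 16), device="cpu")
z = prior.rsample()                          # torch.Size([64, 16])

walk = WienerProcess(batch_size=(64, 16))    # same sizing convention for the random-walk prior
```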
@@ -3,8 +3,8 @@ from torch.distributions import kl
 from . import Normal

 def get_process_from_normal(normal_dist):
-    mean = torch.cat([normal_dist.mean[:, 0], normal_dist.mean[:, 1:] - normal_dist.mean[:, -1]], axis=1)
-    variance = torch.cat([normal_dist.variance[:, 0], normal_dist.variance[:, 1:] - normal_dist.variance[:, -1]], axis=1)
+    mean = torch.cat([normal_dist.mean[:, 0].unsqueeze(1), normal_dist.mean[:, 1:] - normal_dist.mean[:,:-1]], axis=1)
+    variance = torch.cat([normal_dist.variance[:, 0].unsqueeze(1), normal_dist.variance[:, 1:] + normal_dist.variance[:,:-1]], axis=1)
     return RandomWalk(mean, variance.sqrt())

 # Trick described in Bayer & al.
......
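The fix above does three things: `mean[:, 0]` is unsqueezed so `torch.cat` gets tensors of matching rank, the means become first differences (the old code subtracted the last mean from every step instead of the previous one), and the variances of consecutive steps are summed rather than subtracted, which matches a random-walk parameterization when consecutive steps are treated as independent Gaussians. A small numeric check of what the new lines compute (using `dim=` rather than the `axis=` alias; same result):

```python
import torch

mean = torch.tensor([[1.0, 3.0, 6.0]])   # cumulative means over 3 time steps
var  = torch.tensor([[0.5, 0.5, 0.5]])

inc_mean = torch.cat([mean[:, 0].unsqueeze(1), mean[:, 1:] - mean[:, :-1]], dim=1)
inc_var  = torch.cat([var[:, 0].unsqueeze(1),  var[:, 1:]  + var[:, :-1]],  dim=1)

print(inc_mean)   # tensor([[1., 2., 3.]])
print(inc_var)    # tensor([[0.5000, 1.0000, 1.0000]])
```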