Use torch.range instead of torch.linspace because of rounding issues.

For #160
This commit is contained in:
Brandon Amos 2016-07-12 11:45:15 -04:00
parent a0d0ba41fa
commit 6534788a41
No known key found for this signature in database
GPG Key ID: E9B7164CB72D6B6F
2 changed files with 3 additions and 1 deletions

View File

@ -218,7 +218,8 @@ function dataset:__init(...)
if clsLength == 0 then if clsLength == 0 then
error('Class has zero samples: ' .. self.classes[i]) error('Class has zero samples: ' .. self.classes[i])
else else
self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + clsLength, clsLength):long() -- self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + clsLength, clsLength):long()
self.classList[i] = torch.range(runningIndex + 1, runningIndex + clsLength):long()
self.imageClass[{{runningIndex + 1, runningIndex + clsLength}}]:fill(i) self.imageClass[{{runningIndex + 1, runningIndex + clsLength}}]:fill(i)
end end
runningIndex = runningIndex + clsLength runningIndex = runningIndex + clsLength

View File

@ -221,6 +221,7 @@ function dataset:__init(...)
error('Class has zero samples: ' .. self.classes[i]) error('Class has zero samples: ' .. self.classes[i])
else else
self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + clsLength, clsLength):long() self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + clsLength, clsLength):long()
self.classList[i] = torch.range(runningIndex + 1, runningIndex + clsLength):long()
self.imageClass[{{runningIndex + 1, runningIndex + clsLength}}]:fill(i) self.imageClass[{{runningIndex + 1, runningIndex + clsLength}}]:fill(i)
end end
runningIndex = runningIndex + clsLength runningIndex = runningIndex + clsLength