mirror of https://github.com/davisking/dlib.git
Made add_prev output a tensor with dimensions that are the max of each of the
dimensions of its inputs rather than always outputting a tensor that has the dimensions of its immediate predecessors.
This commit is contained in:
parent
b9332698fe
commit
911638638d
|
@ -1447,8 +1447,13 @@ namespace dlib
|
||||||
template <typename SUBNET>
|
template <typename SUBNET>
|
||||||
void forward(const SUBNET& sub, resizable_tensor& output)
|
void forward(const SUBNET& sub, resizable_tensor& output)
|
||||||
{
|
{
|
||||||
output.copy_size(sub.get_output());
|
auto&& t1 = sub.get_output();
|
||||||
tt::add(output, sub.get_output(), layer<tag>(sub).get_output());
|
auto&& t2 = layer<tag>(sub).get_output();
|
||||||
|
output.set_size(std::max(t1.num_samples(),t2.num_samples()),
|
||||||
|
std::max(t1.k(),t2.k()),
|
||||||
|
std::max(t1.nr(),t2.nr()),
|
||||||
|
std::max(t1.nc(),t2.nc()));
|
||||||
|
tt::add(output, t1, t2);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SUBNET>
|
template <typename SUBNET>
|
||||||
|
|
|
@ -1600,7 +1600,13 @@ namespace dlib
|
||||||
what layer to add to the output of the previous layer. The result of this
|
what layer to add to the output of the previous layer. The result of this
|
||||||
addition is output by add_prev_. Finally, the addition happens pointwise
|
addition is output by add_prev_. Finally, the addition happens pointwise
|
||||||
according to 4D tensor arithmetic. If the dimensions don't match then
|
according to 4D tensor arithmetic. If the dimensions don't match then
|
||||||
missing elements are presumed to be equal to 0.
|
missing elements are presumed to be equal to 0. Moreover, each dimension
|
||||||
|
of the output tensor is equal to the maximum dimension of either of the
|
||||||
|
inputs. That is, if the tensors A and B are being added to produce C then:
|
||||||
|
- C.num_samples() == max(A.num_samples(), B.num_samples())
|
||||||
|
- C.k() == max(A.k(), B.k())
|
||||||
|
- C.nr() == max(A.nr(), B.nr())
|
||||||
|
- C.nc() == max(A.nc(), B.nc())
|
||||||
!*/
|
!*/
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
Loading…
Reference in New Issue