Blender  V2.93
multi_function_network_optimization.cc
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or
3  * modify it under the terms of the GNU General Public License
4  * as published by the Free Software Foundation; either version 2
5  * of the License, or (at your option) any later version.
6  *
7  * This program is distributed in the hope that it will be useful,
8  * but WITHOUT ANY WARRANTY; without even the implied warranty of
9  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10  * GNU General Public License for more details.
11  *
12  * You should have received a copy of the GNU General Public License
13  * along with this program; if not, write to the Free Software Foundation,
14  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15  */
16 
21 /* Used to check if two multi-functions have the exact same type. */
22 #include <typeinfo>
23 
27 
28 #include "BLI_disjoint_set.hh"
29 #include "BLI_ghash.h"
30 #include "BLI_map.hh"
31 #include "BLI_multi_value_map.hh"
32 #include "BLI_rand.h"
33 #include "BLI_stack.hh"
34 
36 
37 /* -------------------------------------------------------------------- */
/**
 * Assign `new_value` to `tag`.
 * \return True when the stored value was different before the assignment.
 */
static bool set_tag_and_check_if_modified(bool &tag, bool new_value)
{
  const bool was_different = (tag != new_value);
  /* Writing the same value again is a no-op, so the assignment can be unconditional. */
  tag = new_value;
  return was_different;
}
50 
52 {
53  Array<bool> is_to_the_left(network.node_id_amount(), false);
54  Stack<MFNode *> nodes_to_check;
55 
56  for (MFNode *node : nodes) {
57  is_to_the_left[node->id()] = true;
58  nodes_to_check.push(node);
59  }
60 
61  while (!nodes_to_check.is_empty()) {
62  MFNode &node = *nodes_to_check.pop();
63 
64  for (MFInputSocket *input_socket : node.inputs()) {
65  MFOutputSocket *origin = input_socket->origin();
66  if (origin != nullptr) {
67  MFNode &origin_node = origin->node();
68  if (set_tag_and_check_if_modified(is_to_the_left[origin_node.id()], true)) {
69  nodes_to_check.push(&origin_node);
70  }
71  }
72  }
73  }
74 
75  return is_to_the_left;
76 }
77 
79 {
80  Array<bool> is_to_the_right(network.node_id_amount(), false);
81  Stack<MFNode *> nodes_to_check;
82 
83  for (MFNode *node : nodes) {
84  is_to_the_right[node->id()] = true;
85  nodes_to_check.push(node);
86  }
87 
88  while (!nodes_to_check.is_empty()) {
89  MFNode &node = *nodes_to_check.pop();
90 
91  for (MFOutputSocket *output_socket : node.outputs()) {
92  for (MFInputSocket *target_socket : output_socket->targets()) {
93  MFNode &target_node = target_socket->node();
94  if (set_tag_and_check_if_modified(is_to_the_right[target_node.id()], true)) {
95  nodes_to_check.push(&target_node);
96  }
97  }
98  }
99  }
100 
101  return is_to_the_right;
102 }
103 
105  Span<bool> id_mask,
106  bool mask_value)
107 {
108  Vector<MFNode *> nodes;
109  for (int id : id_mask.index_range()) {
110  if (id_mask[id] == mask_value) {
111  MFNode *node = network.node_or_null_by_id(id);
112  if (node != nullptr) {
113  nodes.append(node);
114  }
115  }
116  }
117  return nodes;
118 }
119 
122 /* -------------------------------------------------------------------- */
130 {
131  Array<bool> node_is_used_mask = mask_nodes_to_the_left(network,
132  network.dummy_nodes().cast<MFNode *>());
133  Vector<MFNode *> nodes_to_remove = find_nodes_based_on_mask(network, node_is_used_mask, false);
134  network.remove(nodes_to_remove);
135 }
136 
139 /* -------------------------------------------------------------------- */
144 {
145  if (node->has_unlinked_inputs()) {
146  return false;
147  }
148  if (node->function().depends_on_context()) {
149  return false;
150  }
151  return true;
152 }
153 
155 {
156  Vector<MFNode *> non_constant_nodes;
157  non_constant_nodes.extend(network.dummy_nodes().cast<MFNode *>());
158 
159  for (MFFunctionNode *node : network.function_nodes()) {
161  non_constant_nodes.append(node);
162  }
163  }
164  return non_constant_nodes;
165 }
166 
168  Span<bool> is_not_constant_mask)
169 {
170  for (MFInputSocket *target_socket : output_socket->targets()) {
171  MFNode &target_node = target_socket->node();
172  bool target_is_not_constant = is_not_constant_mask[target_node.id()];
173  if (target_is_not_constant) {
174  return true;
175  }
176  }
177  return false;
178 }
179 
181 {
182  for (MFInputSocket *target_socket : output_socket->targets()) {
183  if (target_socket->node().is_dummy()) {
184  return target_socket;
185  }
186  }
187  return nullptr;
188 }
189 
191  MFNetwork &network, Vector<MFDummyNode *> &r_temporary_nodes)
192 {
193  Vector<MFNode *> non_constant_nodes = find_non_constant_nodes(network);
194  Array<bool> is_not_constant_mask = mask_nodes_to_the_right(network, non_constant_nodes);
195  Vector<MFNode *> constant_nodes = find_nodes_based_on_mask(network, is_not_constant_mask, false);
196 
197  Vector<MFInputSocket *> sockets_to_compute;
198  for (MFNode *node : constant_nodes) {
199  if (node->inputs().size() == 0) {
200  continue;
201  }
202 
203  for (MFOutputSocket *output_socket : node->outputs()) {
204  MFDataType data_type = output_socket->data_type();
205  if (output_has_non_constant_target_node(output_socket, is_not_constant_mask)) {
206  MFInputSocket *dummy_target = try_find_dummy_target_socket(output_socket);
207  if (dummy_target == nullptr) {
208  dummy_target = &network.add_output("Dummy", data_type);
209  network.add_link(*output_socket, *dummy_target);
210  r_temporary_nodes.append(&dummy_target->node().as_dummy());
211  }
212 
213  sockets_to_compute.append(dummy_target);
214  }
215  }
216  }
217  return sockets_to_compute;
218 }
219 
222  ResourceScope &scope)
223 {
224  for (int param_index : network_fn.param_indices()) {
225  MFParamType param_type = network_fn.param_type(param_index);
226  MFDataType data_type = param_type.data_type();
227 
228  switch (data_type.category()) {
229  case MFDataType::Single: {
230  /* Allocates memory for a single constant folded value. */
231  const CPPType &cpp_type = data_type.single_type();
232  void *buffer = scope.linear_allocator().allocate(cpp_type.size(), cpp_type.alignment());
233  GMutableSpan array{cpp_type, buffer, 1};
234  params.add_uninitialized_single_output(array);
235  break;
236  }
237  case MFDataType::Vector: {
238  /* Allocates memory for a constant folded vector. */
239  const CPPType &cpp_type = data_type.vector_base_type();
240  GVectorArray &vector_array = scope.construct<GVectorArray>(AT, cpp_type, 1);
241  params.add_vector_output(vector_array);
242  break;
243  }
244  }
245  }
246 }
247 
250  ResourceScope &scope,
251  MFNetwork &network)
252 {
253  Array<MFOutputSocket *> folded_sockets{network_fn.param_indices().size(), nullptr};
254 
255  for (int param_index : network_fn.param_indices()) {
256  MFParamType param_type = network_fn.param_type(param_index);
257  MFDataType data_type = param_type.data_type();
258 
259  const MultiFunction *constant_fn = nullptr;
260 
261  switch (data_type.category()) {
262  case MFDataType::Single: {
263  const CPPType &cpp_type = data_type.single_type();
264  GMutableSpan array = params.computed_array(param_index);
265  void *buffer = array.data();
266  scope.add(buffer, array.type().destruct_cb(), AT);
267 
268  constant_fn = &scope.construct<CustomMF_GenericConstant>(AT, cpp_type, buffer);
269  break;
270  }
271  case MFDataType::Vector: {
272  GVectorArray &vector_array = params.computed_vector_array(param_index);
273  GSpan array = vector_array[0];
274  constant_fn = &scope.construct<CustomMF_GenericConstantArray>(AT, array);
275  break;
276  }
277  }
278 
279  MFFunctionNode &folded_node = network.add_function(*constant_fn);
280  folded_sockets[param_index] = &folded_node.output(0);
281  }
282  return folded_sockets;
283 }
284 
286  MFNetwork &network, Span<const MFInputSocket *> sockets_to_compute, ResourceScope &scope)
287 {
288  MFNetworkEvaluator network_fn{{}, sockets_to_compute};
289 
291  MFParamsBuilder params{network_fn, 1};
292  prepare_params_for_constant_folding(network_fn, params, scope);
293  network_fn.call({0}, params, context);
294  return add_constant_folded_sockets(network_fn, params, scope, network);
295 }
296 
297 class MyClass {
298  MFDummyNode node;
299 };
300 
305 {
306  Vector<MFDummyNode *> temporary_nodes;
307  Vector<MFInputSocket *> inputs_to_fold = find_constant_inputs_to_fold(network, temporary_nodes);
308  if (inputs_to_fold.size() == 0) {
309  return;
310  }
311 
313  network, inputs_to_fold, scope);
314 
315  for (int i : inputs_to_fold.index_range()) {
316  MFOutputSocket &original_socket = *inputs_to_fold[i]->origin();
317  network.relink(original_socket, *folded_sockets[i]);
318  }
319 
320  network.remove(temporary_nodes.as_span().cast<MFNode *>());
321 }
322 
325 /* -------------------------------------------------------------------- */
330 {
331  if (node.function().depends_on_context()) {
332  return BLI_rng_get_uint(rng);
333  }
334  if (node.has_unlinked_inputs()) {
335  return BLI_rng_get_uint(rng);
336  }
337 
338  uint64_t combined_inputs_hash = 394659347u;
339  for (MFInputSocket *input_socket : node.inputs()) {
340  MFOutputSocket *origin_socket = input_socket->origin();
341  uint64_t input_hash = BLI_ghashutil_combine_hash(node_hashes[origin_socket->node().id()],
342  origin_socket->index());
343  combined_inputs_hash = BLI_ghashutil_combine_hash(combined_inputs_hash, input_hash);
344  }
345 
346  uint64_t function_hash = node.function().hash();
347  uint64_t node_hash = BLI_ghashutil_combine_hash(combined_inputs_hash, function_hash);
348  return node_hash;
349 }
350 
356 {
357  RNG *rng = BLI_rng_new(0);
358  Array<uint64_t> node_hashes(network.node_id_amount());
359  Array<bool> node_is_hashed(network.node_id_amount(), false);
360 
361  /* No dummy nodes are not assumed to output the same values. */
362  for (MFDummyNode *node : network.dummy_nodes()) {
363  uint64_t node_hash = BLI_rng_get_uint(rng);
364  node_hashes[node->id()] = node_hash;
365  node_is_hashed[node->id()] = true;
366  }
367 
368  Stack<MFFunctionNode *> nodes_to_check;
369  nodes_to_check.push_multiple(network.function_nodes());
370 
371  while (!nodes_to_check.is_empty()) {
372  MFFunctionNode &node = *nodes_to_check.peek();
373  if (node_is_hashed[node.id()]) {
374  nodes_to_check.pop();
375  continue;
376  }
377 
378  /* Make sure that origin nodes are hashed first. */
379  bool all_dependencies_ready = true;
380  for (MFInputSocket *input_socket : node.inputs()) {
381  MFOutputSocket *origin_socket = input_socket->origin();
382  if (origin_socket != nullptr) {
383  MFNode &origin_node = origin_socket->node();
384  if (!node_is_hashed[origin_node.id()]) {
385  all_dependencies_ready = false;
386  nodes_to_check.push(&origin_node.as_function());
387  }
388  }
389  }
390  if (!all_dependencies_ready) {
391  continue;
392  }
393 
394  uint64_t node_hash = compute_node_hash(node, rng, node_hashes);
395  node_hashes[node.id()] = node_hash;
396  node_is_hashed[node.id()] = true;
397  nodes_to_check.pop();
398  }
399 
400  BLI_rng_free(rng);
401  return node_hashes;
402 }
403 
405  Span<uint64_t> node_hashes)
406 {
407  MultiValueMap<uint64_t, MFNode *> nodes_by_hash;
408  for (int id : IndexRange(network.node_id_amount())) {
409  MFNode *node = network.node_or_null_by_id(id);
410  if (node != nullptr) {
411  uint64_t node_hash = node_hashes[id];
412  nodes_by_hash.add(node_hash, node);
413  }
414  }
415  return nodes_by_hash;
416 }
417 
418 static bool functions_are_equal(const MultiFunction &a, const MultiFunction &b)
419 {
420  if (&a == &b) {
421  return true;
422  }
423  if (typeid(a) == typeid(b)) {
424  return a.equals(b);
425  }
426  return false;
427 }
428 
429 static bool nodes_output_same_values(DisjointSet &cache, const MFNode &a, const MFNode &b)
430 {
431  if (cache.in_same_set(a.id(), b.id())) {
432  return true;
433  }
434 
435  if (a.is_dummy() || b.is_dummy()) {
436  return false;
437  }
438  if (!functions_are_equal(a.as_function().function(), b.as_function().function())) {
439  return false;
440  }
441  for (int i : a.inputs().index_range()) {
442  const MFOutputSocket *origin_a = a.input(i).origin();
443  const MFOutputSocket *origin_b = b.input(i).origin();
444  if (origin_a == nullptr || origin_b == nullptr) {
445  return false;
446  }
447  if (!nodes_output_same_values(cache, origin_a->node(), origin_b->node())) {
448  return false;
449  }
450  }
451 
452  cache.join(a.id(), b.id());
453  return true;
454 }
455 
456 static void relink_duplicate_nodes(MFNetwork &network,
457  MultiValueMap<uint64_t, MFNode *> &nodes_by_hash)
458 {
459  DisjointSet same_node_cache{network.node_id_amount()};
460 
461  for (Span<MFNode *> nodes_with_same_hash : nodes_by_hash.values()) {
462  if (nodes_with_same_hash.size() <= 1) {
463  continue;
464  }
465 
466  Vector<MFNode *, 16> nodes_to_check = nodes_with_same_hash;
467  while (nodes_to_check.size() >= 2) {
468  Vector<MFNode *, 16> remaining_nodes;
469 
470  MFNode &deduplicated_node = *nodes_to_check[0];
471  for (MFNode *node : nodes_to_check.as_span().drop_front(1)) {
472  /* This is true with fairly high probability, but hash collisions can happen. So we have to
473  * check if the node actually output the same values. */
474  if (nodes_output_same_values(same_node_cache, deduplicated_node, *node)) {
475  for (int i : deduplicated_node.outputs().index_range()) {
476  network.relink(node->output(i), deduplicated_node.output(i));
477  }
478  }
479  else {
480  remaining_nodes.append(node);
481  }
482  }
483  nodes_to_check = std::move(remaining_nodes);
484  }
485  }
486 }
487 
493 {
494  Array<uint64_t> node_hashes = compute_node_hashes(network);
495  MultiValueMap<uint64_t, MFNode *> nodes_by_hash = group_nodes_by_hash(network, node_hashes);
496  relink_duplicate_nodes(network, nodes_by_hash);
497 }
498 
501 } // namespace blender::fn::mf_network_optimization
size_t BLI_ghashutil_combine_hash(size_t hash_a, size_t hash_b)
Random number functions.
void BLI_rng_free(struct RNG *rng) ATTR_NONNULL(1)
Definition: rand.cc:76
unsigned int BLI_rng_get_uint(struct RNG *rng) ATTR_WARN_UNUSED_RESULT ATTR_NONNULL(1)
Definition: rand.cc:104
struct RNG * BLI_rng_new(unsigned int seed)
Definition: rand.cc:54
#define AT
T * data()
Definition: util_array.h:208
void join(int64_t x, int64_t y)
bool in_same_set(int64_t x, int64_t y)
constexpr int64_t size() const
void * allocate(const int64_t size, const int64_t alignment)
void add(const Key &key, const Value &value)
MapType::ValueIterator values() const
T & construct(const char *name, Args &&... args)
LinearAllocator & linear_allocator()
T * add(std::unique_ptr< T > resource, const char *name)
constexpr IndexRange index_range() const
Definition: BLI_span.hh:414
bool is_empty() const
Definition: BLI_stack.hh:321
void push(const T &value)
Definition: BLI_stack.hh:227
void push_multiple(Span< T > values)
Definition: BLI_stack.hh:294
int64_t size() const
Definition: BLI_vector.hh:662
void append(const T &value)
Definition: BLI_vector.hh:438
Span< T > as_span() const
Definition: BLI_vector.hh:340
IndexRange index_range() const
Definition: BLI_vector.hh:887
void extend(Span< T > array)
Definition: BLI_vector.hh:515
int64_t size() const
Definition: FN_cpp_type.hh:280
int64_t alignment() const
Definition: FN_cpp_type.hh:291
const CPPType & vector_base_type() const
const CPPType & single_type() const
const MultiFunction & function() const
Span< MFFunctionNode * > function_nodes()
void relink(MFOutputSocket &old_output, MFOutputSocket &new_output)
Span< MFDummyNode * > dummy_nodes()
void add_link(MFOutputSocket &from, MFInputSocket &to)
MFFunctionNode & add_function(const MultiFunction &function)
MFInputSocket & add_output(StringRef name, MFDataType data_type)
MFOutputSocket & output(int index)
MFInputSocket & input(int index)
Span< MFOutputSocket * > outputs()
MFParamType param_type(int param_index) const
IndexRange param_indices() const
OperationNode * node
uiWidgetBaseParameters params[MAX_WIDGET_BASE_BATCH]
__kernel void ccl_constant KernelData ccl_global void ccl_global char ccl_global int ccl_global char ccl_global unsigned int ccl_global float * buffer
static unsigned a[3]
Definition: RandGen.cpp:92
static Array< uint64_t > compute_node_hashes(MFNetwork &network)
static bool nodes_output_same_values(DisjointSet &cache, const MFNode &a, const MFNode &b)
static Vector< MFNode * > find_nodes_based_on_mask(MFNetwork &network, Span< bool > id_mask, bool mask_value)
static bool function_node_can_be_constant(MFFunctionNode *node)
static void relink_duplicate_nodes(MFNetwork &network, MultiValueMap< uint64_t, MFNode * > &nodes_by_hash)
void constant_folding(MFNetwork &network, ResourceScope &scope)
static bool set_tag_and_check_if_modified(bool &tag, bool new_value)
static bool functions_are_equal(const MultiFunction &a, const MultiFunction &b)
static Vector< MFNode * > find_non_constant_nodes(MFNetwork &network)
static uint64_t compute_node_hash(MFFunctionNode &node, RNG *rng, Span< uint64_t > node_hashes)
static MFInputSocket * try_find_dummy_target_socket(MFOutputSocket *output_socket)
static Array< bool > mask_nodes_to_the_right(MFNetwork &network, Span< MFNode * > nodes)
static Array< MFOutputSocket * > add_constant_folded_sockets(const MultiFunction &network_fn, MFParamsBuilder &params, ResourceScope &scope, MFNetwork &network)
static Array< MFOutputSocket * > compute_constant_sockets_and_add_folded_nodes(MFNetwork &network, Span< const MFInputSocket * > sockets_to_compute, ResourceScope &scope)
static bool output_has_non_constant_target_node(MFOutputSocket *output_socket, Span< bool > is_not_constant_mask)
static void prepare_params_for_constant_folding(const MultiFunction &network_fn, MFParamsBuilder &params, ResourceScope &scope)
static Vector< MFInputSocket * > find_constant_inputs_to_fold(MFNetwork &network, Vector< MFDummyNode * > &r_temporary_nodes)
static MultiValueMap< uint64_t, MFNode * > group_nodes_by_hash(MFNetwork &network, Span< uint64_t > node_hashes)
static Array< bool > mask_nodes_to_the_left(MFNetwork &network, Span< MFNode * > nodes)
struct SELECTID_Context context
Definition: select_engine.c:47
unsigned __int64 uint64_t
Definition: stdint.h:93
Definition: rand.cc:48