Blender  V2.93
FN_multi_function_network_test.cc
Go to the documentation of this file.
1 /* Apache License, Version 2.0 */
2 
3 #include "testing/testing.h"
4 
8 
9 namespace blender::fn::tests {
10 namespace {
11 
12 TEST(multi_function_network, Test1)
13 {
14  CustomMF_SI_SO<int, int> add_10_fn("add 10", [](int value) { return value + 10; });
15  CustomMF_SI_SI_SO<int, int, int> multiply_fn("multiply", [](int a, int b) { return a * b; });
16 
17  MFNetwork network;
18 
19  MFNode &node1 = network.add_function(add_10_fn);
20  MFNode &node2 = network.add_function(multiply_fn);
21  MFOutputSocket &input_socket = network.add_input("Input", MFDataType::ForSingle<int>());
22  MFInputSocket &output_socket = network.add_output("Output", MFDataType::ForSingle<int>());
23  network.add_link(node1.output(0), node2.input(0));
24  network.add_link(node1.output(0), node2.input(1));
25  network.add_link(node2.output(0), output_socket);
26  network.add_link(input_socket, node1.input(0));
27 
28  MFNetworkEvaluator network_fn{{&input_socket}, {&output_socket}};
29 
30  {
31  Array<int> values = {4, 6, 1, 2, 0};
32  Array<int> results(values.size(), 0);
33 
34  MFParamsBuilder params(network_fn, values.size());
35  params.add_readonly_single_input(values.as_span());
36  params.add_uninitialized_single_output(results.as_mutable_span());
37 
38  MFContextBuilder context;
39 
40  network_fn.call({0, 2, 3, 4}, params, context);
41 
42  EXPECT_EQ(results[0], 14 * 14);
43  EXPECT_EQ(results[1], 0);
44  EXPECT_EQ(results[2], 11 * 11);
45  EXPECT_EQ(results[3], 12 * 12);
46  EXPECT_EQ(results[4], 10 * 10);
47  }
48  {
49  int value = 3;
50  Array<int> results(5, 0);
51 
52  MFParamsBuilder params(network_fn, results.size());
53  params.add_readonly_single_input(&value);
54  params.add_uninitialized_single_output(results.as_mutable_span());
55 
56  MFContextBuilder context;
57 
58  network_fn.call({1, 2, 4}, params, context);
59 
60  EXPECT_EQ(results[0], 0);
61  EXPECT_EQ(results[1], 13 * 13);
62  EXPECT_EQ(results[2], 13 * 13);
63  EXPECT_EQ(results[3], 0);
64  EXPECT_EQ(results[4], 13 * 13);
65  }
66 }
67 
68 class ConcatVectorsFunction : public MultiFunction {
69  public:
70  ConcatVectorsFunction()
71  {
72  static MFSignature signature = create_signature();
73  this->set_signature(&signature);
74  }
75 
76  static MFSignature create_signature()
77  {
78  MFSignatureBuilder signature{"Concat Vectors"};
79  signature.vector_mutable<int>("A");
80  signature.vector_input<int>("B");
81  return signature.build();
82  }
83 
84  void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
85  {
86  GVectorArray &a = params.vector_mutable(0);
87  const GVVectorArray &b = params.readonly_vector_input(1);
88  a.extend(mask, b);
89  }
90 };
91 
92 class AppendFunction : public MultiFunction {
93  public:
94  AppendFunction()
95  {
96  static MFSignature signature = create_signature();
97  this->set_signature(&signature);
98  }
99 
100  static MFSignature create_signature()
101  {
102  MFSignatureBuilder signature{"Append"};
103  signature.vector_mutable<int>("Vector");
104  signature.single_input<int>("Value");
105  return signature.build();
106  }
107 
108  void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
109  {
110  GVectorArray_TypedMutableRef<int> vectors = params.vector_mutable<int>(0);
111  const VArray<int> &values = params.readonly_single_input<int>(1);
112 
113  for (int64_t i : mask) {
114  vectors.append(i, values[i]);
115  }
116  }
117 };
118 
119 class SumVectorFunction : public MultiFunction {
120  public:
121  SumVectorFunction()
122  {
123  static MFSignature signature = create_signature();
124  this->set_signature(&signature);
125  }
126 
127  static MFSignature create_signature()
128  {
129  MFSignatureBuilder signature{"Sum Vectors"};
130  signature.vector_input<int>("Vector");
131  signature.single_output<int>("Sum");
132  return signature.build();
133  }
134 
135  void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
136  {
137  const VVectorArray<int> &vectors = params.readonly_vector_input<int>(0);
138  MutableSpan<int> sums = params.uninitialized_single_output<int>(1);
139 
140  for (int64_t i : mask) {
141  int sum = 0;
142  for (int j : IndexRange(vectors.get_vector_size(i))) {
143  sum += vectors.get_vector_element(i, j);
144  }
145  sums[i] = sum;
146  }
147  }
148 };
149 
150 class CreateRangeFunction : public MultiFunction {
151  public:
152  CreateRangeFunction()
153  {
154  static MFSignature signature = create_signature();
155  this->set_signature(&signature);
156  }
157 
158  static MFSignature create_signature()
159  {
160  MFSignatureBuilder signature{"Create Range"};
161  signature.single_input<int>("Size");
162  signature.vector_output<int>("Range");
163  return signature.build();
164  }
165 
166  void call(IndexMask mask, MFParams params, MFContext UNUSED(context)) const override
167  {
168  const VArray<int> &sizes = params.readonly_single_input<int>(0, "Size");
169  GVectorArray_TypedMutableRef<int> ranges = params.vector_output<int>(1, "Range");
170 
171  for (int64_t i : mask) {
172  int size = sizes[i];
173  for (int j : IndexRange(size)) {
174  ranges.append(i, j);
175  }
176  }
177  }
178 };
179 
180 TEST(multi_function_network, Test2)
181 {
182  CustomMF_SI_SO<int, int> add_3_fn("add 3", [](int value) { return value + 3; });
183 
184  ConcatVectorsFunction concat_vectors_fn;
185  AppendFunction append_fn;
186  SumVectorFunction sum_fn;
187  CreateRangeFunction create_range_fn;
188 
189  MFNetwork network;
190 
191  MFOutputSocket &input1 = network.add_input("Input 1", MFDataType::ForVector<int>());
192  MFOutputSocket &input2 = network.add_input("Input 2", MFDataType::ForSingle<int>());
193  MFInputSocket &output1 = network.add_output("Output 1", MFDataType::ForVector<int>());
194  MFInputSocket &output2 = network.add_output("Output 2", MFDataType::ForSingle<int>());
195 
196  MFNode &node1 = network.add_function(add_3_fn);
197  MFNode &node2 = network.add_function(create_range_fn);
198  MFNode &node3 = network.add_function(concat_vectors_fn);
199  MFNode &node4 = network.add_function(sum_fn);
200  MFNode &node5 = network.add_function(append_fn);
201  MFNode &node6 = network.add_function(sum_fn);
202 
203  network.add_link(input2, node1.input(0));
204  network.add_link(node1.output(0), node2.input(0));
205  network.add_link(node2.output(0), node3.input(1));
206  network.add_link(input1, node3.input(0));
207  network.add_link(input1, node4.input(0));
208  network.add_link(node4.output(0), node5.input(1));
209  network.add_link(node3.output(0), node5.input(0));
210  network.add_link(node5.output(0), node6.input(0));
211  network.add_link(node3.output(0), output1);
212  network.add_link(node6.output(0), output2);
213 
214  // std::cout << network.to_dot() << "\n\n";
215 
216  MFNetworkEvaluator network_fn{{&input1, &input2}, {&output1, &output2}};
217 
218  {
219  Array<int> input_value_1 = {3, 6};
220  int input_value_2 = 4;
221 
222  GVectorArray output_value_1(CPPType::get<int32_t>(), 5);
223  Array<int> output_value_2(5, -1);
224 
225  MFParamsBuilder params(network_fn, 5);
226  GVVectorArrayForSingleGSpan inputs_1{input_value_1.as_span(), 5};
227  params.add_readonly_vector_input(inputs_1);
228  params.add_readonly_single_input(&input_value_2);
229  params.add_vector_output(output_value_1);
230  params.add_uninitialized_single_output(output_value_2.as_mutable_span());
231 
232  MFContextBuilder context;
233 
234  network_fn.call({1, 2, 4}, params, context);
235 
236  EXPECT_EQ(output_value_1[0].size(), 0);
237  EXPECT_EQ(output_value_1[1].size(), 9);
238  EXPECT_EQ(output_value_1[2].size(), 9);
239  EXPECT_EQ(output_value_1[3].size(), 0);
240  EXPECT_EQ(output_value_1[4].size(), 9);
241 
242  EXPECT_EQ(output_value_2[0], -1);
243  EXPECT_EQ(output_value_2[1], 39);
244  EXPECT_EQ(output_value_2[2], 39);
245  EXPECT_EQ(output_value_2[3], -1);
246  EXPECT_EQ(output_value_2[4], 39);
247  }
248  {
249  GVectorArray input_value_1(CPPType::get<int32_t>(), 3);
250  GVectorArray_TypedMutableRef<int> input_value_1_ref{input_value_1};
251  input_value_1_ref.extend(0, {3, 4, 5});
252  input_value_1_ref.extend(1, {1, 2});
253 
254  Array<int> input_value_2 = {4, 2, 3};
255 
256  GVectorArray output_value_1(CPPType::get<int32_t>(), 3);
257  Array<int> output_value_2(3, -1);
258 
259  MFParamsBuilder params(network_fn, 3);
260  params.add_readonly_vector_input(input_value_1);
261  params.add_readonly_single_input(input_value_2.as_span());
262  params.add_vector_output(output_value_1);
263  params.add_uninitialized_single_output(output_value_2.as_mutable_span());
264 
265  MFContextBuilder context;
266 
267  network_fn.call({0, 1, 2}, params, context);
268 
269  EXPECT_EQ(output_value_1[0].size(), 10);
270  EXPECT_EQ(output_value_1[1].size(), 7);
271  EXPECT_EQ(output_value_1[2].size(), 6);
272 
273  EXPECT_EQ(output_value_2[0], 45);
274  EXPECT_EQ(output_value_2[1], 16);
275  EXPECT_EQ(output_value_2[2], 15);
276  }
277 }
278 
279 } // namespace
280 } // namespace blender::fn::tests
EXPECT_EQ(BLI_expr_pylike_eval(expr, nullptr, 0, &result), EXPR_PYLIKE_INVALID)
#define UNUSED(x)
static DBVT_INLINE btScalar size(const btDbvtVolume &a)
Definition: btDbvt.cpp:52
static T sum(const btAlignedObjectArray< T > &items)
uiWidgetBaseParameters params[MAX_WIDGET_BASE_BATCH]
static unsigned a[3]
Definition: RandGen.cpp:92
TEST(cpp_type, Size)
struct SELECTID_Context context
Definition: select_engine.c:47
__int64 int64_t
Definition: stdint.h:92
ccl_device_inline float4 mask(const int4 &mask, const float4 &a)