Blender V5.0
socket_value_inference.cc
Go to the documentation of this file.
1/* SPDX-FileCopyrightText: 2024 Blender Authors
2 *
3 * SPDX-License-Identifier: GPL-2.0-or-later */
4
5#include <regex>
6
7#include "NOD_menu_value.hh"
12
13#include "DNA_anim_types.h"
14#include "DNA_material_types.h"
15#include "DNA_node_types.h"
16
20#include "BKE_node_runtime.hh"
22
23#include "ANIM_action.hh"
25
26#include "BLI_listbase.h"
27#include "BLI_stack.hh"
28
29namespace blender::nodes {
30
32 private:
33 ResourceScope &scope_;
34 bke::ComputeContextCache &compute_context_cache_;
35
36 Stack<SocketInContext> value_tasks_;
41 Map<SocketInContext, InferenceValue> all_socket_values_;
42
43 FunctionRef<InferenceValue(int group_input_i)> group_input_value_fn_;
44
49 Set<const bNodeSocket *> animated_sockets_;
50 Set<const bNodeTree *> trees_with_handled_animation_data_;
51 std::optional<Span<bool>> top_level_ignored_inputs_;
52
53 const bNodeTree &root_tree_;
54
55 public:
57 const bNodeTree &tree,
58 ResourceScope &scope,
59 bke::ComputeContextCache &compute_context_cache,
60 const FunctionRef<InferenceValue(int group_input_i)> group_input_value_fn,
61 const std::optional<Span<bool>> top_level_ignored_inputs)
62 : scope_(scope),
63 compute_context_cache_(compute_context_cache),
64 group_input_value_fn_(group_input_value_fn),
65 top_level_ignored_inputs_(top_level_ignored_inputs),
66 root_tree_(tree)
67 {
68 root_tree_.ensure_topology_cache();
69 root_tree_.ensure_interface_cache();
70 this->ensure_animation_data_processed(root_tree_);
71 }
72
74 {
75 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(socket);
76 if (value.has_value()) {
77 return *value;
78 }
79 if (socket->owner_tree().has_available_link_cycle()) {
81 }
82
83 BLI_assert(value_tasks_.is_empty());
84 value_tasks_.push(socket);
85
86 while (!value_tasks_.is_empty()) {
87 const SocketInContext &socket = value_tasks_.peek();
88 this->value_task(socket);
89 if (&socket == &value_tasks_.peek()) {
90 /* The task is finished if it hasn't added any new task it depends on. */
91 value_tasks_.pop();
92 }
93 }
94
95 return all_socket_values_.lookup(socket);
96 }
97
98 private:
99 void value_task(const SocketInContext &socket)
100 {
101 if (all_socket_values_.contains(socket)) {
102 /* Task is done already. */
103 return;
104 }
105 const bNode &node = socket->owner_node();
106 if (node.is_undefined() && !node.is_custom_group()) {
107 all_socket_values_.add_new(socket, InferenceValue::Unknown());
108 return;
109 }
110 const CPPType *base_type = socket->typeinfo->base_cpp_type;
111 if (!base_type) {
112 /* The socket type is unknown for some reason (maybe a socket type from the future?). */
113 all_socket_values_.add_new(socket, InferenceValue::Unknown());
114 return;
115 }
116 if (socket->is_input()) {
117 this->value_task__input(socket);
118 }
119 else {
120 this->value_task__output(socket);
121 }
122 }
123
124 void value_task__output(const SocketInContext &socket)
125 {
126 const NodeInContext node = socket.owner_node();
127 if (node->is_muted()) {
128 this->value_task__output__muted_node(socket);
129 return;
130 }
131 switch (node->type_legacy) {
132 case NODE_GROUP:
133 case NODE_CUSTOM_GROUP: {
134 this->value_task__output__group_node(socket);
135 return;
136 }
137 case NODE_GROUP_INPUT: {
138 this->value_task__output__group_input_node(socket);
139 return;
140 }
141 case NODE_REROUTE: {
142 this->value_task__output__reroute_node(socket);
143 return;
144 }
145 case GEO_NODE_SWITCH: {
146 this->value_task__output__generic_switch(
148 return;
149 }
151 this->value_task__output__generic_switch(
153 return;
154 }
156 if (socket->index() == 0) {
157 this->value_task__output__generic_switch(
159 }
160 else {
161 this->value_task__output__menu_switch_selection(socket);
162 }
163 return;
164 }
165 case SH_NODE_MIX: {
166 this->value_task__output__generic_switch(
168 return;
169 }
170 case SH_NODE_MIX_SHADER: {
171 this->value_task__output__generic_switch(
173 return;
174 }
175 case SH_NODE_MATH: {
176 this->value_task__output__float_math(socket);
177 return;
178 }
179 case SH_NODE_VECTOR_MATH: {
180 this->value_task__output__vector_math(socket);
181 return;
182 }
184 this->value_task__output__integer_math(socket);
185 return;
186 }
188 this->value_task__output__boolean_math(socket);
189 return;
190 }
191 case GEO_NODE_WARNING: {
192 this->value_task__output__warning(socket);
193 return;
194 }
195 default: {
196 if (node->is_type("NodeEnableOutput")) {
197 this->value_task__output__enable_output(socket);
198 return;
199 }
200 if (node->typeinfo->build_multi_function) {
201 this->value_task__output__multi_function_node(socket);
202 return;
203 }
204 break;
205 }
206 }
207 /* If none of the above cases work, the socket value is set to null which means that it is
208 * unknown/dynamic. */
209 all_socket_values_.add_new(socket, InferenceValue::Unknown());
210 }
211
212 void value_task__output__group_node(const SocketInContext &socket)
213 {
214 const NodeInContext node = socket.owner_node();
215 const bNodeTree *group = reinterpret_cast<const bNodeTree *>(node->id);
216 if (!group || ID_MISSING(&group->id)) {
217 all_socket_values_.add_new(socket, InferenceValue::Unknown());
218 return;
219 }
220 group->ensure_topology_cache();
221 if (group->has_available_link_cycle()) {
222 all_socket_values_.add_new(socket, InferenceValue::Unknown());
223 return;
224 }
225 this->ensure_animation_data_processed(*group);
226 const bNode *group_output_node = group->group_output_node();
227 if (!group_output_node) {
228 /* Can't compute the value if the group does not have an output node. */
229 all_socket_values_.add_new(socket, InferenceValue::Unknown());
230 return;
231 }
232 const ComputeContext &group_context = compute_context_cache_.for_group_node(
233 socket.context, node->identifier, &node->owner_tree());
234 const SocketInContext socket_in_group{&group_context,
235 &group_output_node->input_socket(socket->index())};
236 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(socket_in_group);
237 if (!value.has_value()) {
238 this->push_value_task(socket_in_group);
239 return;
240 }
241 all_socket_values_.add_new(socket, *value);
242 }
243
244 void value_task__output__group_input_node(const SocketInContext &socket)
245 {
246 const bool is_root_context = socket.context == nullptr;
247 if (is_root_context) {
248 InferenceValue value = InferenceValue::Unknown();
249 if (group_input_value_fn_) {
250 value = group_input_value_fn_(socket->index());
251 }
252 all_socket_values_.add_new(socket, value);
253 return;
254 }
255
256 const bke::GroupNodeComputeContext &group_context =
257 *static_cast<const bke::GroupNodeComputeContext *>(socket.context);
258 const SocketInContext group_node_input{group_context.parent(),
259 &group_context.node()->input_socket(socket->index())};
260 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(group_node_input);
261 if (!value.has_value()) {
262 this->push_value_task(group_node_input);
263 return;
264 }
265 all_socket_values_.add_new(socket, *value);
266 }
267
268 void value_task__output__reroute_node(const SocketInContext &socket)
269 {
270 const SocketInContext input_socket = socket.owner_node().input_socket(0);
271 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(input_socket);
272 if (!value.has_value()) {
273 this->push_value_task(input_socket);
274 return;
275 }
276 all_socket_values_.add_new(socket, *value);
277 }
278
279 void value_task__output__menu_switch_selection(const SocketInContext &socket)
280 {
281 const NodeInContext node = socket.owner_node();
282 const SocketInContext input_socket = node.input_socket(0);
283 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(input_socket);
284 if (!value.has_value()) {
285 this->push_value_task(input_socket);
286 return;
287 }
288 const std::optional<MenuValue> menu_value = value->get_if_primitive<MenuValue>();
289 if (!menu_value.has_value()) {
290 all_socket_values_.add_new(socket, InferenceValue::Unknown());
291 return;
292 }
293 const NodeMenuSwitch &storage = *static_cast<const NodeMenuSwitch *>(node->storage);
294 const NodeEnumItem &item = storage.enum_definition.items_array[socket->index() - 1];
295 const bool is_selected = item.identifier == menu_value->value;
296 all_socket_values_.add_new(socket, this->make_primitive_inference_value(is_selected));
297 }
298
299 void value_task__output__float_math(const SocketInContext &socket)
300 {
301 const NodeInContext node = socket.owner_node();
302 const NodeMathOperation operation = NodeMathOperation(node->custom1);
303 switch (operation) {
304 case NODE_MATH_MULTIPLY: {
305 this->value_task__output__generic_eval(
306 socket, [&](const Span<InferenceValue> inputs) -> std::optional<InferenceValue> {
307 const std::optional<float> a = inputs[0].get_if_primitive<float>();
308 const std::optional<float> b = inputs[1].get_if_primitive<float>();
309 if (a == 0.0f || b == 0.0f) {
310 return this->make_primitive_inference_value(0.0f);
311 }
312 if (a.has_value() && b.has_value()) {
313 return this->make_primitive_inference_value(*a * *b);
314 }
315 return std::nullopt;
316 });
317 break;
318 }
319 default: {
320 this->value_task__output__multi_function_node(socket);
321 break;
322 }
323 }
324 }
325
326 void value_task__output__vector_math(const SocketInContext &socket)
327 {
328 const NodeInContext node = socket.owner_node();
329 const NodeVectorMathOperation operation = NodeVectorMathOperation(node->custom1);
330 switch (operation) {
332 this->value_task__output__generic_eval(
333 socket, [&](const Span<InferenceValue> inputs) -> std::optional<InferenceValue> {
334 const std::optional<float3> a = inputs[0].get_if_primitive<float3>();
335 const std::optional<float3> b = inputs[1].get_if_primitive<float3>();
336 if (a == float3(0.0f) || b == float3(0.0f)) {
337 return this->make_primitive_inference_value(float3(0.0f));
338 }
339 if (a.has_value() && b.has_value()) {
340 return this->make_primitive_inference_value(float3(*a * *b));
341 }
342 return std::nullopt;
343 });
344 break;
345 }
347 this->value_task__output__generic_eval(
348 socket, [&](const Span<InferenceValue> inputs) -> std::optional<InferenceValue> {
349 const std::optional<float3> a = inputs[0].get_if_primitive<float3>();
350 const std::optional<float> scale = inputs[3].get_if_primitive<float>();
351 if (a == float3(0.0f) || scale == 0.0f) {
352 return this->make_primitive_inference_value(float3(0.0f));
353 }
354 if (a.has_value() && scale.has_value()) {
355 return this->make_primitive_inference_value(float3(*a * *scale));
356 }
357 return std::nullopt;
358 });
359 break;
360 }
361 default: {
362 this->value_task__output__multi_function_node(socket);
363 break;
364 }
365 }
366 }
367
368 void value_task__output__integer_math(const SocketInContext &socket)
369 {
370 const NodeInContext node = socket.owner_node();
371 const NodeIntegerMathOperation operation = NodeIntegerMathOperation(node->custom1);
372 switch (operation) {
374 this->value_task__output__generic_eval(
375 socket, [&](const Span<InferenceValue> inputs) -> std::optional<InferenceValue> {
376 const std::optional<int> a = inputs[0].get_if_primitive<int>();
377 const std::optional<int> b = inputs[1].get_if_primitive<int>();
378 if (a == 0 || b == 0) {
379 return this->make_primitive_inference_value(0);
380 }
381 if (a.has_value() && b.has_value()) {
382 return this->make_primitive_inference_value(*a * *b);
383 }
384 return std::nullopt;
385 });
386 break;
387 }
388 default: {
389 this->value_task__output__multi_function_node(socket);
390 break;
391 }
392 }
393 }
394
395 void value_task__output__boolean_math(const SocketInContext &socket)
396 {
397 const NodeInContext node = socket.owner_node();
398 const NodeBooleanMathOperation operation = NodeBooleanMathOperation(node->custom1);
399
400 const auto handle_binary_op =
401 [&](FunctionRef<std::optional<bool>(std::optional<bool>, std::optional<bool>)> fn) {
402 this->value_task__output__generic_eval(
403 socket, [&](const Span<InferenceValue> inputs) -> std::optional<InferenceValue> {
404 const std::optional<bool> a = inputs[0].get_if_primitive<bool>();
405 const std::optional<bool> b = inputs[1].get_if_primitive<bool>();
406 const std::optional<bool> result = fn(a, b);
407 if (result.has_value()) {
408 return this->make_primitive_inference_value(*result);
409 }
410 return std::nullopt;
411 });
412 };
413 switch (operation) {
415 handle_binary_op(
416 [](const std::optional<bool> &a, const std::optional<bool> &b) -> std::optional<bool> {
417 if (a == false || b == false) {
418 return false;
419 }
420 if (a.has_value() && b.has_value()) {
421 return *a && *b;
422 }
423 return std::nullopt;
424 });
425 break;
426 }
428 handle_binary_op(
429 [](const std::optional<bool> &a, const std::optional<bool> &b) -> std::optional<bool> {
430 if (a == true || b == true) {
431 return true;
432 }
433 if (a.has_value() && b.has_value()) {
434 return *a || *b;
435 }
436 return std::nullopt;
437 });
438 break;
439 }
441 handle_binary_op(
442 [](const std::optional<bool> &a, const std::optional<bool> &b) -> std::optional<bool> {
443 if (a == false || b == false) {
444 return true;
445 }
446 if (a.has_value() && b.has_value()) {
447 return !(*a && *b);
448 }
449 return std::nullopt;
450 });
451 break;
452 }
454 handle_binary_op(
455 [](const std::optional<bool> &a, const std::optional<bool> &b) -> std::optional<bool> {
456 if (a == true || b == true) {
457 return false;
458 }
459 if (a.has_value() && b.has_value()) {
460 return !(*a || *b);
461 }
462 return std::nullopt;
463 });
464 break;
465 }
467 handle_binary_op(
468 [](const std::optional<bool> &a, const std::optional<bool> &b) -> std::optional<bool> {
469 if (a == false || b == true) {
470 return true;
471 }
472 if (a.has_value() && b.has_value()) {
473 return !*a || *b;
474 }
475 return std::nullopt;
476 });
477 break;
478 }
480 handle_binary_op(
481 [](const std::optional<bool> &a, const std::optional<bool> &b) -> std::optional<bool> {
482 if (a == false || b == true) {
483 return false;
484 }
485 if (a.has_value() && b.has_value()) {
486 return *a && !*b;
487 }
488 return std::nullopt;
489 });
490 break;
491 }
492 default: {
493 this->value_task__output__multi_function_node(socket);
494 break;
495 }
496 }
497 }
498
499 void value_task__output__warning(const SocketInContext &socket)
500 {
501 const NodeInContext node = socket.owner_node();
502 const SocketInContext show_input_socket = node.input_socket(0);
503 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(show_input_socket);
504 if (!value.has_value()) {
505 this->push_value_task(show_input_socket);
506 return;
507 }
508 all_socket_values_.add_new(socket, *value);
509 }
510
511 void value_task__output__enable_output(const SocketInContext &socket)
512 {
513 const NodeInContext node = socket.owner_node();
514 const SocketInContext enable_input_socket = node.input_socket(0);
515 const SocketInContext value_input_socket = node.input_socket(1);
516
517 const std::optional<InferenceValue> keep_value = all_socket_values_.lookup_try(
518 enable_input_socket);
519 if (!keep_value.has_value()) {
520 this->push_value_task(enable_input_socket);
521 return;
522 }
523 if (!keep_value->is_primitive_value()) {
524 all_socket_values_.add_new(socket, InferenceValue::Unknown());
525 return;
526 }
527 const bool keep = keep_value->get_primitive<bool>();
528 if (!keep) {
529 const CPPType &type = *socket->typeinfo->base_cpp_type;
530 all_socket_values_.add_new(socket, InferenceValue::from_primitive(type.default_value()));
531 return;
532 }
533 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(value_input_socket);
534 if (!value.has_value()) {
535 this->push_value_task(value_input_socket);
536 return;
537 }
538 all_socket_values_.add_new(socket, *value);
539 }
540
545 void value_task__output__generic_switch(
546 const SocketInContext &socket,
547 const FunctionRef<bool(const SocketInContext &socket, InferenceValue condition)>
548 is_selected_socket)
549 {
550 const NodeInContext node = socket.owner_node();
551 BLI_assert(node->input_sockets().size() >= 1);
552 BLI_assert(node->output_sockets().size() >= 1);
553
554 const SocketInContext condition_socket{socket.context,
555 get_first_available_bsocket(node->input_sockets())};
556 const std::optional<InferenceValue> condition_value = all_socket_values_.lookup_try(
557 condition_socket);
558 if (!condition_value.has_value()) {
559 this->push_value_task(condition_socket);
560 return;
561 }
562 if (condition_value->is_unknown()) {
563 /* The condition value is not a simple static value, so the output is unknown. */
564 all_socket_values_.add_new(socket, InferenceValue::Unknown());
565 return;
566 }
567 Vector<const bNodeSocket *> selected_inputs;
568 for (const int input_i :
569 node->input_sockets().index_range().drop_front(condition_socket->index() + 1))
570 {
571 const SocketInContext input_socket = node.input_socket(input_i);
572 if (!input_socket->is_available()) {
573 continue;
574 }
575 if (input_socket->type == SOCK_CUSTOM && STREQ(input_socket->idname, "NodeSocketVirtual")) {
576 continue;
577 }
578 const bool is_selected = is_selected_socket(input_socket, *condition_value);
579 if (is_selected) {
580 selected_inputs.append(input_socket.socket);
581 }
582 }
583 if (selected_inputs.is_empty()) {
584 all_socket_values_.add_new(socket, InferenceValue::Unknown());
585 return;
586 }
587 if (selected_inputs.size() == 1) {
588 /* A single input is selected, so just pass through this value without regarding others. */
589 const SocketInContext selected_input{socket.context, selected_inputs[0]};
590 const std::optional<InferenceValue> input_value = all_socket_values_.lookup_try(
591 selected_input);
592 if (!input_value.has_value()) {
593 this->push_value_task(selected_input);
594 return;
595 }
596 all_socket_values_.add_new(socket, *input_value);
597 return;
598 }
599
600 /* Multiple inputs are selected. */
601 if (node->typeinfo->build_multi_function) {
602 /* Try to compute the output value from the multiple selected inputs. */
603 this->value_task__output__multi_function_node(socket);
604 return;
605 }
606 /* Can't compute the output value, so set it to be unknown. */
607 all_socket_values_.add_new(socket, InferenceValue::Unknown());
608 }
609
610 void value_task__output__generic_eval(
611 const SocketInContext &socket,
612 const FunctionRef<std::optional<InferenceValue>(Span<InferenceValue> inputs)> eval_fn)
613 {
614 const NodeInContext node = socket.owner_node();
615 const int inputs_num = node->input_sockets().size();
616
617 Array<InferenceValue, 16> input_values(inputs_num, InferenceValue::Unknown());
618 std::optional<int> next_unknown_input_index;
619 for (const int input_i : IndexRange(inputs_num)) {
620 const SocketInContext input_socket = node.input_socket(input_i);
621 if (!input_socket->is_available()) {
622 continue;
623 }
624 const std::optional<InferenceValue> input_value = all_socket_values_.lookup_try(
625 input_socket);
626 if (!input_value.has_value()) {
627 next_unknown_input_index = input_i;
628 break;
629 }
630 input_values[input_i] = *input_value;
631 }
632 const std::optional<InferenceValue> output_value = eval_fn(input_values);
633 if (output_value.has_value()) {
634 /* Was able to compute the output value. */
635 all_socket_values_.add_new(socket, *output_value);
636 return;
637 }
638 if (!next_unknown_input_index.has_value()) {
639 /* The output is still unknown even though we know as much about the inputs as possible
640 * already. */
641 all_socket_values_.add_new(socket, InferenceValue::Unknown());
642 return;
643 }
644 /* Request the next input socket. */
645 const SocketInContext next_input = node.input_socket(*next_unknown_input_index);
646 this->push_value_task(next_input);
647 }
648
649 void value_task__output__multi_function_node(const SocketInContext &socket)
650 {
651 const NodeInContext node = socket.owner_node();
652 const int inputs_num = node->input_sockets().size();
653
654 /* Gather all input values are return early if any of them is not known. */
655 Vector<const void *> input_values(inputs_num);
656 for (const int input_i : IndexRange(inputs_num)) {
657 const SocketInContext input_socket = node.input_socket(input_i);
658 const std::optional<InferenceValue> input_value = all_socket_values_.lookup_try(
659 input_socket);
660 if (!input_value.has_value()) {
661 this->push_value_task(input_socket);
662 return;
663 }
664 if (!input_value->is_primitive_value()) {
665 all_socket_values_.add_new(socket, InferenceValue::Unknown());
666 return;
667 }
668 input_values[input_i] = input_value->get_primitive_ptr();
669 }
670
671 /* Get the multi-function for the node. */
672 NodeMultiFunctionBuilder builder{*node.node, node->owner_tree()};
673 node->typeinfo->build_multi_function(builder);
674 const mf::MultiFunction &fn = builder.function();
675
676 /* We only evaluate the node for a single value here. */
677 const IndexMask mask(1);
678
679 /* Prepare parameters for the multi-function evaluation. */
680 mf::ParamsBuilder params{fn, &mask};
681 for (const int input_i : IndexRange(inputs_num)) {
682 const SocketInContext input_socket = node.input_socket(input_i);
683 if (!input_socket->is_available()) {
684 continue;
685 }
686 params.add_readonly_single_input(
687 GPointer(input_socket->typeinfo->base_cpp_type, input_values[input_i]));
688 }
689 for (const int output_i : node->output_sockets().index_range()) {
690 const SocketInContext output_socket = node.output_socket(output_i);
691 if (!output_socket->is_available()) {
692 continue;
693 }
694 /* Allocate memory for the output value. */
695 const CPPType &base_type = *output_socket->typeinfo->base_cpp_type;
696 void *value = scope_.allocate_owned(base_type);
697 params.add_uninitialized_single_output(GMutableSpan(base_type, value, 1));
698 all_socket_values_.add_new(output_socket, InferenceValue::from_primitive(value));
699 }
700 mf::ContextBuilder context;
701 /* Actually evaluate the multi-function. The outputs will be written into the memory allocated
702 * earlier, which has been added to #all_socket_values_ already. */
703 fn.call(mask, params, context);
704 }
705
706 void value_task__output__muted_node(const SocketInContext &socket)
707 {
708 const NodeInContext node = socket.owner_node();
709
710 SocketInContext input_socket;
711 for (const bNodeLink &internal_link : node->internal_links()) {
712 if (internal_link.tosock == socket.socket) {
713 input_socket = SocketInContext{socket.context, internal_link.fromsock};
714 break;
715 }
716 }
717 if (!input_socket) {
718 /* The output does not have an internal link to an input. */
719 all_socket_values_.add_new(socket, InferenceValue::Unknown());
720 return;
721 }
722 const std::optional<InferenceValue> input_value = all_socket_values_.lookup_try(input_socket);
723 if (!input_value.has_value()) {
724 this->push_value_task(input_socket);
725 return;
726 }
727 const InferenceValue converted_value = this->convert_type_if_necessary(
728 *input_value, *input_socket.socket, *socket.socket);
729 all_socket_values_.add_new(socket, converted_value);
730 }
731
732 void value_task__input(const SocketInContext &socket)
733 {
734 if (socket->is_multi_input()) {
735 /* Can't know the single value of a multi-input. */
736 all_socket_values_.add_new(socket, InferenceValue::Unknown());
737 return;
738 }
739 const bNodeLink *source_link = nullptr;
740 const Span<const bNodeLink *> connected_links = socket->directly_linked_links();
741 for (const bNodeLink *link : connected_links) {
742 if (!link->is_used()) {
743 continue;
744 }
745 if (link->fromnode->is_dangling_reroute()) {
746 continue;
747 }
748 source_link = link;
749 break;
750 }
751 if (!source_link) {
752 this->value_task__input__unlinked(socket);
753 return;
754 }
755 this->value_task__input__linked({socket.context, source_link->fromsock}, socket);
756 }
757
758 void value_task__input__unlinked(const SocketInContext &socket)
759 {
760 if (this->treat_socket_as_unknown(socket)) {
761 all_socket_values_.add_new(socket, InferenceValue::Unknown());
762 return;
763 }
764 if (animated_sockets_.contains(socket.socket)) {
765 /* The value of animated sockets is not known statically. */
766 all_socket_values_.add_new(socket, InferenceValue::Unknown());
767 return;
768 }
769 if (const SocketDeclaration *socket_decl = socket.socket->runtime->declaration) {
770 if (socket_decl->input_field_type == InputSocketFieldType::Implicit) {
771 /* Implicit fields inputs don't have a single static value. */
772 all_socket_values_.add_new(socket, InferenceValue::Unknown());
773 return;
774 }
775 }
776
777 void *value_buffer = scope_.allocate_owned(*socket->typeinfo->base_cpp_type);
778 socket->typeinfo->get_base_cpp_value(socket->default_value, value_buffer);
779 all_socket_values_.add_new(socket, InferenceValue::from_primitive(value_buffer));
780 }
781
782 void value_task__input__linked(const SocketInContext &from_socket,
783 const SocketInContext &to_socket)
784 {
785 const std::optional<InferenceValue> from_value = all_socket_values_.lookup_try(from_socket);
786 if (!from_value.has_value()) {
787 this->push_value_task(from_socket);
788 return;
789 }
790 const InferenceValue converted_value = this->convert_type_if_necessary(
791 *from_value, *from_socket.socket, *to_socket.socket);
792 all_socket_values_.add_new(to_socket, converted_value);
793 }
794
795 InferenceValue convert_type_if_necessary(const InferenceValue &src,
796 const bNodeSocket &from_socket,
797 const bNodeSocket &to_socket)
798 {
799 if (!src.is_primitive_value()) {
801 }
802 const CPPType *from_type = from_socket.typeinfo->base_cpp_type;
803 const CPPType *to_type = to_socket.typeinfo->base_cpp_type;
804 if (from_type == to_type) {
805 return src;
806 }
807 if (!to_type) {
809 }
810 const bke::DataTypeConversions &conversions = bke::get_implicit_type_conversions();
811 if (!conversions.is_convertible(*from_type, *to_type)) {
813 }
814 void *dst = scope_.allocate_owned(*to_type);
815 conversions.convert_to_uninitialized(*from_type, *to_type, src.get_primitive_ptr(), dst);
817 }
818
819 bool treat_socket_as_unknown(const SocketInContext &socket) const
820 {
821 if (!top_level_ignored_inputs_.has_value()) {
822 return false;
823 }
824 if (socket.context) {
825 return false;
826 }
827 if (socket->is_output()) {
828 return false;
829 }
830 return (*top_level_ignored_inputs_)[socket->index_in_all_inputs()];
831 }
832
833 void ensure_animation_data_processed(const bNodeTree &tree)
834 {
835 if (!trees_with_handled_animation_data_.add(&tree)) {
836 return;
837 }
838 if (!tree.adt) {
839 return;
840 }
841
842 static std::regex pattern(R"#(nodes\‍["(.*)"\‍].inputs\‍[(\d+)\‍].default_value)#");
843 MultiValueMap<StringRef, int> animated_inputs_by_node_name;
844 auto handle_rna_path = [&](const char *rna_path) {
845 std::cmatch match;
846 if (!std::regex_match(rna_path, match, pattern)) {
847 return;
848 }
849 const StringRef node_name{match[1].first, match[1].second - match[1].first};
850 const int socket_index = std::stoi(match[2]);
851 animated_inputs_by_node_name.add(node_name, socket_index);
852 };
853
854 /* Gather all inputs controlled by fcurves. */
855 if (tree.adt->action) {
857 tree.adt->action->wrap(), tree.adt->slot_handle, [&](const FCurve &fcurve) {
858 handle_rna_path(fcurve.rna_path);
859 });
860 }
861 /* Gather all inputs controlled by drivers. */
862 LISTBASE_FOREACH (const FCurve *, driver, &tree.adt->drivers) {
863 handle_rna_path(driver->rna_path);
864 }
865
866 /* Actually find the #bNodeSocket for each controlled input. */
867 if (!animated_inputs_by_node_name.is_empty()) {
868 for (const bNode *node : tree.all_nodes()) {
869 const Span<int> animated_inputs = animated_inputs_by_node_name.lookup(node->name);
870 const Span<const bNodeSocket *> input_sockets = node->input_sockets();
871 for (const int socket_index : animated_inputs) {
872 if (socket_index < 0 || socket_index >= input_sockets.size()) {
873 /* This can happen when the animation data is not immediately updated after a socket is
874 * removed. */
875 continue;
876 }
877 const bNodeSocket &socket = *input_sockets[socket_index];
878 animated_sockets_.add(&socket);
879 }
880 }
881 }
882 }
883
884 void push_value_task(const SocketInContext &socket)
885 {
886 value_tasks_.push(socket);
887 }
888
889 template<typename T> InferenceValue make_primitive_inference_value(const T &value)
890 {
891 static_assert(is_same_any_v<std::decay_t<T>, bool, float, int, float3>);
892 return InferenceValue::from_primitive(&scope_.construct<T>(value));
893 }
894
895 static const bNodeSocket *get_first_available_bsocket(const Span<const bNodeSocket *> sockets)
896 {
897 for (const bNodeSocket *socket : sockets) {
898 if (socket->is_available()) {
899 return socket;
900 }
901 }
902 return nullptr;
903 }
904};
905
907 const bNodeTree &tree,
908 ResourceScope &scope,
909 bke::ComputeContextCache &compute_context_cache,
910 const FunctionRef<InferenceValue(int group_input_i)> group_input_value_fn,
911 const std::optional<Span<bool>> top_level_ignored_inputs)
912 : impl_(scope.construct<SocketValueInferencerImpl>(
913 tree, scope, compute_context_cache, group_input_value_fn, top_level_ignored_inputs))
914{
915}
916
918{
919 return impl_.get_socket_value(socket);
920}
921
923
924bool is_socket_selected__switch(const SocketInContext &socket, const InferenceValue &condition)
925{
926 if (!condition.is_primitive_value()) {
927 return true;
928 }
929 const bool is_true = condition.get_primitive<bool>();
930 const int selected_index = is_true ? 2 : 1;
931 return socket->index() == selected_index;
932}
933
935 const InferenceValue &condition)
936{
937 if (!condition.is_primitive_value()) {
938 return true;
939 }
940 const int index = condition.get_primitive<int>();
941 return socket->index() == index + 1;
942}
943
945 const InferenceValue &condition)
946{
947 if (!condition.is_primitive_value()) {
948 return true;
949 }
950 const NodeMenuSwitch &storage = *static_cast<const NodeMenuSwitch *>(
951 socket->owner_node().storage);
952 const int menu_value = condition.get_primitive<int>();
953 const NodeEnumItem &item = storage.enum_definition.items_array[socket->index() - 1];
954 return menu_value == item.identifier;
955}
956
958{
959 if (!condition.is_primitive_value()) {
960 return true;
961 }
962 const NodeShaderMix &storage = *static_cast<const NodeShaderMix *>(socket.owner_node()->storage);
963 if (storage.data_type == SOCK_RGBA && storage.blend_type != MA_RAMP_BLEND) {
964 return true;
965 }
966
967 const bool clamp_factor = storage.clamp_factor != 0;
968 bool only_a = false;
969 bool only_b = false;
970 if (storage.data_type == SOCK_VECTOR && storage.factor_mode == NODE_MIX_MODE_NON_UNIFORM) {
971 const float3 mix_factor = condition.get_primitive<float3>();
972 if (clamp_factor) {
973 only_a = mix_factor.x <= 0.0f && mix_factor.y <= 0.0f && mix_factor.z <= 0.0f;
974 only_b = mix_factor.x >= 1.0f && mix_factor.y >= 1.0f && mix_factor.z >= 1.0f;
975 }
976 else {
977 only_a = float3{0.0f, 0.0f, 0.0f} == mix_factor;
978 only_b = float3{1.0f, 1.0f, 1.0f} == mix_factor;
979 }
980 }
981 else {
982 const float mix_factor = condition.get_primitive<float>();
983 if (clamp_factor) {
984 only_a = mix_factor <= 0.0f;
985 only_b = mix_factor >= 1.0f;
986 }
987 else {
988 only_a = mix_factor == 0.0f;
989 only_b = mix_factor == 1.0f;
990 }
991 }
992 if (only_a) {
993 if (STREQ(socket->name, "B")) {
994 return false;
995 }
996 }
997 if (only_b) {
998 if (STREQ(socket->name, "A")) {
999 return false;
1000 }
1001 }
1002 return true;
1003}
1004
1006 const InferenceValue &condition)
1007{
1008 if (!condition.is_primitive_value()) {
1009 return true;
1010 }
1011 const float mix_factor = condition.get_primitive<float>();
1012 if (mix_factor == 0.0f) {
1013 if (STREQ(socket->identifier, "Shader_001")) {
1014 return false;
1015 }
1016 }
1017 else if (mix_factor == 1.0f) {
1018 if (STREQ(socket->identifier, "Shader")) {
1019 return false;
1020 }
1021 }
1022 return true;
1023}
1024
1025} // namespace switch_node_inference_utils
1026
1027} // namespace blender::nodes
Functions and classes to work with Actions.
Functionality to iterate an Action in various ways.
#define NODE_REROUTE
Definition BKE_node.hh:813
#define NODE_CUSTOM_GROUP
Definition BKE_node.hh:816
#define NODE_GROUP
Definition BKE_node.hh:811
#define NODE_GROUP_INPUT
Definition BKE_node.hh:814
#define SH_NODE_MIX_SHADER
#define GEO_NODE_MENU_SWITCH
#define FN_NODE_INTEGER_MATH
#define SH_NODE_VECTOR_MATH
#define SH_NODE_MATH
#define FN_NODE_BOOLEAN_MATH
#define GEO_NODE_SWITCH
#define SH_NODE_MIX
#define GEO_NODE_WARNING
#define GEO_NODE_INDEX_SWITCH
#define BLI_assert(a)
Definition BLI_assert.h:46
#define LISTBASE_FOREACH(type, var, list)
#define STREQ(a, b)
#define ID_MISSING(_id)
Definition DNA_ID.h:692
struct FCurve FCurve
@ MA_RAMP_BLEND
struct NodeMenuSwitch NodeMenuSwitch
NodeVectorMathOperation
@ NODE_VECTOR_MATH_MULTIPLY
@ NODE_VECTOR_MATH_SCALE
NodeMathOperation
@ NODE_MATH_MULTIPLY
struct bNodeLink bNodeLink
@ NODE_MIX_MODE_NON_UNIFORM
struct bNode bNode
@ SOCK_VECTOR
@ SOCK_CUSTOM
@ SOCK_RGBA
struct bNodeTree bNodeTree
NodeBooleanMathOperation
@ NODE_BOOLEAN_MATH_IMPLY
@ NODE_BOOLEAN_MATH_AND
@ NODE_BOOLEAN_MATH_NAND
@ NODE_BOOLEAN_MATH_OR
@ NODE_BOOLEAN_MATH_NIMPLY
@ NODE_BOOLEAN_MATH_NOR
struct NodeEnumItem NodeEnumItem
NodeIntegerMathOperation
@ NODE_INTEGER_MATH_MULTIPLY
struct bNodeSocket bNodeSocket
const void * default_value() const
Span< Value > lookup(const Key &key) const
void add(const Key &key, const Value &value)
constexpr int64_t size() const
Definition BLI_span.hh:252
int64_t size() const
void append(const T &value)
bool is_empty() const
void add_new(const Key &key, const Value &value)
Definition BLI_map.hh:265
bool contains(const Key &key) const
Definition BLI_map.hh:353
static InferenceValue from_primitive(const void *value)
InferenceValue get_socket_value(const SocketInContext &socket)
SocketValueInferencerImpl(const bNodeTree &tree, ResourceScope &scope, bke::ComputeContextCache &compute_context_cache, const FunctionRef< InferenceValue(int group_input_i)> group_input_value_fn, const std::optional< Span< bool > > top_level_ignored_inputs)
SocketValueInferencer(const bNodeTree &tree, ResourceScope &scope, bke::ComputeContextCache &compute_context_cache, FunctionRef< InferenceValue(int group_input_i)> group_input_value_fn=nullptr, std::optional< Span< bool > > top_level_ignored_inputs=std::nullopt)
InferenceValue get_socket_value(const SocketInContext &socket)
nullptr float
KDTree_3d * tree
uiWidgetBaseParameters params[MAX_WIDGET_BASE_BATCH]
ccl_device_inline float2 mask(const MaskType mask, const float2 a)
#define T
void foreach_fcurve_in_action_slot(Action &action, slot_handle_t handle, FunctionRef< void(FCurve &fcurve)> callback)
const DataTypeConversions & get_implicit_type_conversions()
int context(const bContext *C, const char *member, bContextDataResult *result)
static Type to_type(const GPUType type)
MatBase< T, NumCol, NumRow > scale(const MatBase< T, NumCol, NumRow > &mat, const VectorT &scale)
bool is_socket_selected__shader_mix_node(const SocketInContext &socket, const InferenceValue &condition)
bool is_socket_selected__mix_node(const SocketInContext &socket, const InferenceValue &condition)
bool is_socket_selected__menu_switch(const SocketInContext &socket, const InferenceValue &condition)
bool is_socket_selected__index_switch(const SocketInContext &socket, const InferenceValue &condition)
bool is_socket_selected__switch(const SocketInContext &socket, const InferenceValue &condition)
constexpr bool is_same_any_v
VecBase< float, 3 > float3
static blender::bke::bNodeSocketTemplate inputs[]
NodeEnumItem * items_array
NodeEnumDefinition enum_definition
bNodeSocketTypeHandle * typeinfo
char identifier[64]