51 std::optional<Span<bool>> top_level_ignored_inputs_;
61 const std::optional<
Span<bool>> top_level_ignored_inputs)
63 compute_context_cache_(compute_context_cache),
64 group_input_value_fn_(group_input_value_fn),
65 top_level_ignored_inputs_(top_level_ignored_inputs),
68 root_tree_.ensure_topology_cache();
69 root_tree_.ensure_interface_cache();
70 this->ensure_animation_data_processed(root_tree_);
75 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(socket);
76 if (value.has_value()) {
79 if (socket->owner_tree().has_available_link_cycle()) {
84 value_tasks_.push(socket);
86 while (!value_tasks_.is_empty()) {
88 this->value_task(socket);
89 if (&socket == &value_tasks_.peek()) {
95 return all_socket_values_.lookup(socket);
// NOTE(review): the signature of this function is not part of this excerpt.
// Judging by the call `this->value_task(socket)` elsewhere in the file and by
// the dispatch below, this is presumably the body of `value_task` -- confirm.
//
// Visible logic:
//  - checks whether `socket` already has an entry in `all_socket_values_`,
//  - checks for an owner node that is undefined and not a custom group,
//  - calls `value_task__input` for input sockets and `value_task__output`
//    further down (presumably in the else branch -- not visible here).
// What happens inside the first two checks is not visible in this excerpt.
101 if (all_socket_values_.
contains(socket)) {
105 const bNode &node = socket->owner_node();
106 if (node.is_undefined() && !node.is_custom_group()) {
116 if (socket->is_input()) {
117 this->value_task__input(socket);
120 this->value_task__output(socket);
// Dispatches the value computation for an output socket to a handler that
// depends on the node owning the socket.
//
// \param socket: the output socket (with its compute context) to evaluate.
//
// Visible dispatch:
//  - muted nodes go to `value_task__output__muted_node`,
//  - a switch on `node->type_legacy` selects one of the specialised handlers
//    (group node, group input, reroute, several generic-switch variants,
//    menu-switch selection, float/vector/integer/boolean math, warning),
//  - nodes of type NodeEnableOutput go to `value_task__output__enable_output`,
//  - nodes whose typeinfo provides `build_multi_function` go to
//    `value_task__output__multi_function_node`.
// NOTE(review): the case labels, the arguments of the generic-switch calls and
// the fallback for nodes matching none of the above are not visible in this
// excerpt.
124 void value_task__output(
const SocketInContext &socket)
126 const NodeInContext node = socket.owner_node();
127 if (node->is_muted()) {
128 this->value_task__output__muted_node(socket);
131 switch (node->type_legacy) {
134 this->value_task__output__group_node(socket);
138 this->value_task__output__group_input_node(socket);
142 this->value_task__output__reroute_node(socket);
146 this->value_task__output__generic_switch(
151 this->value_task__output__generic_switch(
// For this node type the handler depends on the index of the output socket:
// index 0 uses the generic switch handling; `menu_switch_selection` is called
// below (presumably for the remaining outputs -- the else line is not visible).
156 if (socket->index() == 0) {
157 this->value_task__output__generic_switch(
161 this->value_task__output__menu_switch_selection(socket);
166 this->value_task__output__generic_switch(
171 this->value_task__output__generic_switch(
176 this->value_task__output__float_math(socket);
180 this->value_task__output__vector_math(socket);
184 this->value_task__output__integer_math(socket);
188 this->value_task__output__boolean_math(socket);
192 this->value_task__output__warning(socket);
196 if (node->is_type(
"NodeEnableOutput")) {
197 this->value_task__output__enable_output(socket);
200 if (node->typeinfo->build_multi_function) {
201 this->value_task__output__multi_function_node(socket);
// Handles an output socket of a group node by forwarding the value of the
// corresponding input socket on the group output node inside the group.
//
// \param socket: output socket of the group node.
//
// Visible steps: make sure the topology cache of `group` exists, check the
// group for an available link cycle, process its animation data, and find its
// group output node. Then a compute context for the group node is obtained
// from `compute_context_cache_`, and the group output node input with the same
// index as `socket` is looked up in that context. If that inner socket has no
// value yet, a task is pushed for it; otherwise its value is stored for `socket`.
// NOTE(review): the definition of `group` and the bodies of the link-cycle and
// missing-output-node checks are not visible in this excerpt.
212 void value_task__output__group_node(
const SocketInContext &socket)
214 const NodeInContext node = socket.owner_node();
220 group->ensure_topology_cache();
221 if (group->has_available_link_cycle()) {
225 this->ensure_animation_data_processed(*group);
226 const bNode *group_output_node = group->group_output_node();
227 if (!group_output_node) {
232 const ComputeContext &group_context = compute_context_cache_.for_group_node(
233 socket.context, node->identifier, &node->owner_tree());
234 const SocketInContext socket_in_group{&group_context,
235 &group_output_node->input_socket(socket->index())};
236 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(socket_in_group);
237 if (!value.has_value()) {
238 this->push_value_task(socket_in_group);
241 all_socket_values_.add_new(socket, *value);
// Handles an output socket of a group input node.
//
// \param socket: output socket of the group input node.
//
// Two cases are visible:
//  - Root context (`socket.context` is null): if `group_input_value_fn_` is
//    set, it is called with the socket index and the result is stored for
//    `socket`. The initial value of `value` is declared on a line that is not
//    visible in this excerpt.
//  - Nested context: the context is cast to `bke::GroupNodeComputeContext`,
//    and the value is taken from the group node input with the same index in
//    the parent context. If that input has no value yet, a task is pushed for
//    it; otherwise its value is stored for `socket`.
244 void value_task__output__group_input_node(
const SocketInContext &socket)
246 const bool is_root_context = socket.context ==
nullptr;
247 if (is_root_context) {
249 if (group_input_value_fn_) {
250 value = group_input_value_fn_(socket->index());
252 all_socket_values_.add_new(socket, value);
// NOTE(review): the static_cast below is unchecked; it presumably relies on
// every non-root context reaching this point being a group node context --
// confirm.
256 const bke::GroupNodeComputeContext &group_context =
257 *
static_cast<const bke::GroupNodeComputeContext *
>(socket.context);
258 const SocketInContext group_node_input{group_context.parent(),
259 &group_context.node()->input_socket(socket->index())};
260 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(group_node_input);
261 if (!value.has_value()) {
262 this->push_value_task(group_node_input);
265 all_socket_values_.add_new(socket, *value);
// Handles the output socket of a reroute node by forwarding the value of the
// node input with index 0.
//
// \param socket: output socket of the reroute node.
//
// If the input has no value yet, a task is pushed for it; otherwise the input
// value is stored unchanged for `socket`.
268 void value_task__output__reroute_node(
const SocketInContext &socket)
270 const SocketInContext input_socket = socket.owner_node().input_socket(0);
271 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(input_socket);
272 if (!value.has_value()) {
273 this->push_value_task(input_socket);
276 all_socket_values_.add_new(socket, *value);
// Computes a boolean output telling whether a menu item is the one selected
// by the menu value on the node input with index 0.
//
// \param socket: the output socket to evaluate.
//
// If the menu input has no value yet, a task is pushed for it. Otherwise the
// value is read as a `MenuValue` primitive, and the stored result is whether
// `item.identifier` equals the menu value.
// NOTE(review): how `item` is chosen and what happens when the input is not a
// `MenuValue` primitive are not visible in this excerpt.
279 void value_task__output__menu_switch_selection(
const SocketInContext &socket)
281 const NodeInContext node = socket.owner_node();
282 const SocketInContext input_socket = node.input_socket(0);
283 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(input_socket);
284 if (!value.has_value()) {
285 this->push_value_task(input_socket);
288 const std::optional<MenuValue> menu_value = value->get_if_primitive<MenuValue>();
289 if (!menu_value.has_value()) {
295 const bool is_selected = item.
identifier == menu_value->value;
296 all_socket_values_.add_new(socket, this->make_primitive_inference_value(is_selected));
// Handles the output of a float math node, with a shortcut that can produce a
// result even when only one operand is known.
//
// \param socket: output socket of the float math node.
//
// The visible callback passed to `value_task__output__generic_eval` reads
// inputs 0 and 1 as optional floats:
//  - if either one is known and equal to 0.0f the result is 0.0f (comparing an
//    empty std::optional with a value yields false, so an unknown operand does
//    not trigger this branch by itself),
//  - if both are known the result is their product.
// The regular multi-function evaluation is also called further down.
// NOTE(review): the condition that selects the shortcut (presumably the
// multiply operation) and the return value of the callback when neither branch
// applies are not visible in this excerpt -- confirm.
299 void value_task__output__float_math(
const SocketInContext &socket)
301 const NodeInContext node = socket.owner_node();
305 this->value_task__output__generic_eval(
306 socket, [&](
const Span<InferenceValue>
inputs) -> std::optional<InferenceValue> {
307 const std::optional<float> a =
inputs[0].get_if_primitive<
float>();
308 const std::optional<float>
b =
inputs[1].get_if_primitive<
float>();
309 if (a == 0.0f ||
b == 0.0f) {
310 return this->make_primitive_inference_value(0.0f);
312 if (a.has_value() &&
b.has_value()) {
313 return this->make_primitive_inference_value(*a * *
b);
320 this->value_task__output__multi_function_node(socket);
// Handles the output of a vector math node, with shortcuts that can produce a
// result even when only one operand is known.
//
// \param socket: output socket of the vector math node.
//
// Two callbacks passed to `value_task__output__generic_eval` are visible:
//  1. Inputs 0 and 1 are read as optional `float3`. A zero vector is returned
//     under a condition that is not visible in this excerpt; if both operands
//     are known, `*a * *b` is returned.
//  2. Input 0 is read as optional `float3` and input 3 as optional float
//     `scale`. If either is known to be zero the result is a zero vector; if
//     both are known the result is `*a * *scale`.
// The regular multi-function evaluation is also called further down.
// NOTE(review): which node operations select each callback is decided on lines
// that are not visible here.
326 void value_task__output__vector_math(
const SocketInContext &socket)
328 const NodeInContext node = socket.owner_node();
332 this->value_task__output__generic_eval(
333 socket, [&](
const Span<InferenceValue>
inputs) -> std::optional<InferenceValue> {
334 const std::optional<float3> a =
inputs[0].get_if_primitive<
float3>();
335 const std::optional<float3>
b =
inputs[1].get_if_primitive<
float3>();
337 return this->make_primitive_inference_value(
float3(0.0f));
339 if (a.has_value() &&
b.has_value()) {
340 return this->make_primitive_inference_value(
float3(*a * *
b));
347 this->value_task__output__generic_eval(
348 socket, [&](
const Span<InferenceValue>
inputs) -> std::optional<InferenceValue> {
349 const std::optional<float3> a =
inputs[0].get_if_primitive<
float3>();
350 const std::optional<float>
scale =
inputs[3].get_if_primitive<
float>();
351 if (a ==
float3(0.0f) || scale == 0.0f) {
352 return this->make_primitive_inference_value(
float3(0.0f));
354 if (a.has_value() &&
scale.has_value()) {
355 return this->make_primitive_inference_value(
float3(*a * *scale));
362 this->value_task__output__multi_function_node(socket);
// Handles the output of an integer math node, with a shortcut that can produce
// a result even when only one operand is known.
//
// \param socket: output socket of the integer math node.
//
// The visible callback passed to `value_task__output__generic_eval` reads
// inputs 0 and 1 as optional ints:
//  - if either one is known and equal to 0 the result is 0,
//  - if both are known the result is their product.
// The regular multi-function evaluation is also called further down.
// NOTE(review): the condition that selects the shortcut (presumably the
// multiply operation) is not visible in this excerpt -- confirm.
368 void value_task__output__integer_math(
const SocketInContext &socket)
370 const NodeInContext node = socket.owner_node();
374 this->value_task__output__generic_eval(
375 socket, [&](
const Span<InferenceValue>
inputs) -> std::optional<InferenceValue> {
376 const std::optional<int> a =
inputs[0].get_if_primitive<
int>();
377 const std::optional<int>
b =
inputs[1].get_if_primitive<
int>();
378 if (a == 0 ||
b == 0) {
379 return this->make_primitive_inference_value(0);
381 if (a.has_value() &&
b.has_value()) {
382 return this->make_primitive_inference_value(*a * *
b);
389 this->value_task__output__multi_function_node(socket);
// Handles the output of a boolean math node, using three-valued logic so that
// a result can be produced even when one operand is unknown.
//
// \param socket: output socket of the boolean math node.
//
// `handle_binary_op` wraps `value_task__output__generic_eval`: it reads inputs
// 0 and 1 as optional bools (empty when the input is not a bool primitive),
// passes them to `fn`, and turns the result of `fn` into a primitive inference
// value.
//
// Six binary callbacks are visible. Each one first tests whether a single known
// operand is enough to decide the result, then tests whether both operands are
// known. The deciding conditions are, in order:
//   1. a == false || b == false
//   2. a == true  || b == true
//   3. a == false || b == false
//   4. a == true  || b == true
//   5. a == false || b == true
//   6. a == false || b == true
// Comparing an empty std::optional with a value yields false, so an unknown
// operand never satisfies these conditions by itself.
// The regular multi-function evaluation is also called further down.
// NOTE(review): the case labels naming the boolean operation of each callback,
// the values the callbacks return, and the handling of an empty `result` in
// `handle_binary_op` are not visible in this excerpt.
395 void value_task__output__boolean_math(
const SocketInContext &socket)
397 const NodeInContext node = socket.owner_node();
400 const auto handle_binary_op =
401 [&](FunctionRef<std::optional<bool>(std::optional<bool>, std::optional<bool>)> fn) {
402 this->value_task__output__generic_eval(
403 socket, [&](
const Span<InferenceValue>
inputs) -> std::optional<InferenceValue> {
404 const std::optional<bool> a =
inputs[0].get_if_primitive<
bool>();
405 const std::optional<bool>
b =
inputs[1].get_if_primitive<
bool>();
406 const std::optional<bool>
result = fn(a,
b);
408 return this->make_primitive_inference_value(*
result);
416 [](
const std::optional<bool> &a,
const std::optional<bool> &
b) -> std::optional<bool> {
417 if (a ==
false ||
b ==
false) {
420 if (a.has_value() &&
b.has_value()) {
429 [](
const std::optional<bool> &a,
const std::optional<bool> &
b) -> std::optional<bool> {
430 if (a ==
true ||
b ==
true) {
433 if (a.has_value() &&
b.has_value()) {
442 [](
const std::optional<bool> &a,
const std::optional<bool> &
b) -> std::optional<bool> {
443 if (a ==
false ||
b ==
false) {
446 if (a.has_value() &&
b.has_value()) {
455 [](
const std::optional<bool> &a,
const std::optional<bool> &
b) -> std::optional<bool> {
456 if (a ==
true ||
b ==
true) {
459 if (a.has_value() &&
b.has_value()) {
468 [](
const std::optional<bool> &a,
const std::optional<bool> &
b) -> std::optional<bool> {
469 if (a ==
false ||
b ==
true) {
472 if (a.has_value() &&
b.has_value()) {
481 [](
const std::optional<bool> &a,
const std::optional<bool> &
b) -> std::optional<bool> {
482 if (a ==
false ||
b ==
true) {
485 if (a.has_value() &&
b.has_value()) {
493 this->value_task__output__multi_function_node(socket);
// Handles the output of a warning node by forwarding the value of the node
// input with index 0 (named `show_input_socket` here).
//
// \param socket: output socket of the warning node.
//
// If the input has no value yet, a task is pushed for it; otherwise the input
// value is stored unchanged for `socket`.
499 void value_task__output__warning(
const SocketInContext &socket)
501 const NodeInContext node = socket.owner_node();
502 const SocketInContext show_input_socket = node.input_socket(0);
503 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(show_input_socket);
504 if (!value.has_value()) {
505 this->push_value_task(show_input_socket);
508 all_socket_values_.add_new(socket, *value);
// Handles the output of a NodeEnableOutput node.
//
// \param socket: output socket of the node.
//
// Input 0 is the enable flag and input 1 is the value to pass through. The
// enable input is evaluated first: if it has no value yet, a task is pushed
// for it. Once it is a known primitive it is read as a bool `keep`. Further
// down the value input is evaluated the same way, and its value is forwarded
// to `socket`.
// NOTE(review): the handling of a non-primitive enable value and the branch
// that uses `keep` and `type` (presumably producing a default value of the
// socket type when the output is disabled) are not visible in this excerpt --
// confirm.
511 void value_task__output__enable_output(
const SocketInContext &socket)
513 const NodeInContext node = socket.owner_node();
514 const SocketInContext enable_input_socket = node.input_socket(0);
515 const SocketInContext value_input_socket = node.input_socket(1);
517 const std::optional<InferenceValue> keep_value = all_socket_values_.lookup_try(
518 enable_input_socket);
519 if (!keep_value.has_value()) {
520 this->push_value_task(enable_input_socket);
523 if (!keep_value->is_primitive_value()) {
527 const bool keep = keep_value->get_primitive<
bool>();
529 const CPPType &type = *socket->typeinfo->base_cpp_type;
533 const std::optional<InferenceValue> value = all_socket_values_.lookup_try(value_input_socket);
534 if (!value.has_value()) {
535 this->push_value_task(value_input_socket);
538 all_socket_values_.add_new(socket, *value);
// Shared handling for switch-like nodes whose output forwards one of several
// inputs depending on a condition input.
//
// \param socket: the output socket to evaluate.
// \param (callback): predicate taking an input socket and the condition value
//   and returning whether that input is selected. It is used below as
//   `is_selected_socket`; the line naming the parameter is not visible in this
//   excerpt.
//
// Visible steps:
//  1. The condition socket is the first available input of the node. If it has
//     no value yet, a task is pushed for it.
//  2. An unknown condition is handled separately (body not visible).
//  3. All inputs after the condition socket are scanned, skipping unavailable
//     sockets and virtual sockets (custom type with idname NodeSocketVirtual).
//     The selected ones are collected.
//  4. If exactly one input is selected, its value is forwarded to `socket`
//     (pushing a task for it first if it has no value yet).
//  5. Further down, nodes that provide `build_multi_function` are evaluated via
//     `value_task__output__multi_function_node`.
// NOTE(review): what is stored when zero or several inputs are selected is not
// visible in this excerpt.
545 void value_task__output__generic_switch(
546 const SocketInContext &socket,
547 const FunctionRef<
bool(
const SocketInContext &socket, InferenceValue condition)>
550 const NodeInContext node = socket.owner_node();
551 BLI_assert(node->input_sockets().size() >= 1);
552 BLI_assert(node->output_sockets().size() >= 1);
554 const SocketInContext condition_socket{socket.context,
555 get_first_available_bsocket(node->input_sockets())};
556 const std::optional<InferenceValue> condition_value = all_socket_values_.lookup_try(
558 if (!condition_value.has_value()) {
559 this->push_value_task(condition_socket);
562 if (condition_value->is_unknown()) {
567 Vector<const bNodeSocket *> selected_inputs;
568 for (
const int input_i :
569 node->input_sockets().index_range().drop_front(condition_socket->index() + 1))
571 const SocketInContext input_socket = node.input_socket(input_i);
572 if (!input_socket->is_available()) {
575 if (input_socket->type ==
SOCK_CUSTOM &&
STREQ(input_socket->idname,
"NodeSocketVirtual")) {
578 const bool is_selected = is_selected_socket(input_socket, *condition_value);
580 selected_inputs.
append(input_socket.socket);
587 if (selected_inputs.
size() == 1) {
589 const SocketInContext selected_input{socket.context, selected_inputs[0]};
590 const std::optional<InferenceValue> input_value = all_socket_values_.lookup_try(
592 if (!input_value.has_value()) {
593 this->push_value_task(selected_input);
596 all_socket_values_.add_new(socket, *input_value);
601 if (node->typeinfo->build_multi_function) {
603 this->value_task__output__multi_function_node(socket);
// Evaluates an output socket with a caller-provided function that may be able
// to produce a result from partially known inputs.
//
// \param socket: the output socket to evaluate.
// \param eval_fn: receives the input values gathered so far (indexed by input
//   socket index) and returns the output value, or an empty optional when the
//   output cannot be decided yet.
//
// Visible steps: for every available input the cached value is looked up. A
// found value is copied into `input_values`; for a missing one the input index
// is remembered in `next_unknown_input_index`. `eval_fn` is then called, and a
// returned value is stored for `socket`. Further down, a task is pushed for the
// remembered unknown input.
// NOTE(review): the declaration and initial contents of `input_values`, whether
// the loop stops at the first missing input, and what is stored when there is
// no result and no unknown input left are not visible in this excerpt.
610 void value_task__output__generic_eval(
611 const SocketInContext &socket,
612 const FunctionRef<std::optional<InferenceValue>(Span<InferenceValue>
inputs)> eval_fn)
614 const NodeInContext node = socket.owner_node();
615 const int inputs_num = node->input_sockets().size();
618 std::optional<int> next_unknown_input_index;
619 for (
const int input_i : IndexRange(inputs_num)) {
620 const SocketInContext input_socket = node.input_socket(input_i);
621 if (!input_socket->is_available()) {
624 const std::optional<InferenceValue> input_value = all_socket_values_.lookup_try(
626 if (!input_value.has_value()) {
627 next_unknown_input_index = input_i;
630 input_values[input_i] = *input_value;
632 const std::optional<InferenceValue> output_value = eval_fn(input_values);
633 if (output_value.has_value()) {
635 all_socket_values_.add_new(socket, *output_value);
638 if (!next_unknown_input_index.has_value()) {
645 const SocketInContext next_input = node.input_socket(*next_unknown_input_index);
646 this->push_value_task(next_input);
// Evaluates a node through the multi-function built by its
// `typeinfo->build_multi_function` callback.
//
// \param socket: an output socket of the node to evaluate.
//
// Visible steps:
//  1. For every input socket the cached value is looked up. A missing value
//     pushes a task for that input; a non-primitive value is handled
//     separately (body not visible). Otherwise a pointer to the primitive value
//     is stored in `input_values`.
//  2. The multi-function is built with `NodeMultiFunctionBuilder`.
//  3. With an index mask of size 1, every available input is added as a
//     read-only single input, and every available output gets storage from
//     `scope_` and is added as an uninitialised single output.
// NOTE(review): the construction of `params`, the actual function call and the
// storing of the computed outputs are not visible in this excerpt. The first
// loop has no visible availability check, unlike the second one -- confirm this
// is intended.
649 void value_task__output__multi_function_node(
const SocketInContext &socket)
651 const NodeInContext node = socket.owner_node();
652 const int inputs_num = node->input_sockets().size();
655 Vector<const void *> input_values(inputs_num);
656 for (
const int input_i : IndexRange(inputs_num)) {
657 const SocketInContext input_socket = node.input_socket(input_i);
658 const std::optional<InferenceValue> input_value = all_socket_values_.lookup_try(
660 if (!input_value.has_value()) {
661 this->push_value_task(input_socket);
664 if (!input_value->is_primitive_value()) {
668 input_values[input_i] = input_value->get_primitive_ptr();
672 NodeMultiFunctionBuilder builder{*node.node, node->owner_tree()};
673 node->typeinfo->build_multi_function(builder);
674 const mf::MultiFunction &fn = builder.function();
677 const IndexMask
mask(1);
681 for (
const int input_i : IndexRange(inputs_num)) {
682 const SocketInContext input_socket = node.input_socket(input_i);
683 if (!input_socket->is_available()) {
686 params.add_readonly_single_input(
687 GPointer(input_socket->typeinfo->base_cpp_type, input_values[input_i]));
689 for (
const int output_i : node->output_sockets().index_range()) {
690 const SocketInContext output_socket = node.output_socket(output_i);
691 if (!output_socket->is_available()) {
695 const CPPType &base_type = *output_socket->typeinfo->base_cpp_type;
696 void *value = scope_.allocate_owned(base_type);
697 params.add_uninitialized_single_output(GMutableSpan(base_type, value, 1));
// Handles an output socket of a muted node by following the internal link of
// the node that ends at this output.
//
// \param socket: output socket of the muted node.
//
// The internal links of the node are searched for one whose `tosock` is
// `socket`; its `fromsock` (in the same context) becomes the input to forward.
// If that input has no value yet, a task is pushed for it. Otherwise its value
// is converted to the type of the output socket where needed and stored.
// NOTE(review): the handling of an output without a matching internal link is
// not visible in this excerpt.
706 void value_task__output__muted_node(
const SocketInContext &socket)
708 const NodeInContext node = socket.owner_node();
710 SocketInContext input_socket;
711 for (
const bNodeLink &internal_link : node->internal_links()) {
712 if (internal_link.tosock == socket.socket) {
713 input_socket = SocketInContext{socket.context, internal_link.fromsock};
722 const std::optional<InferenceValue> input_value = all_socket_values_.lookup_try(input_socket);
723 if (!input_value.has_value()) {
724 this->push_value_task(input_socket);
727 const InferenceValue converted_value = this->convert_type_if_necessary(
728 *input_value, *input_socket.socket, *socket.socket);
729 all_socket_values_.add_new(socket, converted_value);
// Computes the value of an input socket, either from its own stored value or
// from the socket it is linked to.
//
// \param socket: the input socket to evaluate.
//
// Visible steps: multi-input sockets are handled separately (body not
// visible). The directly linked links are scanned, with separate handling for
// links that are not used and for links coming from a dangling reroute. The
// function then calls `value_task__input__unlinked`, or
// `value_task__input__linked` with the `fromsock` of `source_link`.
// NOTE(review): how `source_link` is chosen and the condition that decides
// between the unlinked and linked paths are not visible in this excerpt.
732 void value_task__input(
const SocketInContext &socket)
734 if (socket->is_multi_input()) {
740 const Span<const bNodeLink *> connected_links = socket->directly_linked_links();
741 for (
const bNodeLink *link : connected_links) {
742 if (!link->is_used()) {
745 if (link->fromnode->is_dangling_reroute()) {
752 this->value_task__input__unlinked(socket);
755 this->value_task__input__linked({socket.context, source_link->
fromsock}, socket);
// Computes the value of an input socket that has no incoming link to follow.
//
// \param socket: the unlinked input socket.
//
// Visible checks, in order: whether the socket is to be treated as unknown
// (`treat_socket_as_unknown`), whether it is in `animated_sockets_`, and
// whether it has a socket declaration. Finally a buffer of the base C++ type of
// the socket is allocated from `scope_` and filled from `default_value` via
// `get_base_cpp_value`.
// NOTE(review): the values stored in each of the early branches, and the
// storing of the default value itself, are not visible in this excerpt.
758 void value_task__input__unlinked(
const SocketInContext &socket)
760 if (this->treat_socket_as_unknown(socket)) {
764 if (animated_sockets_.contains(socket.socket)) {
769 if (
const SocketDeclaration *socket_decl = socket.socket->runtime->declaration) {
777 void *value_buffer = scope_.allocate_owned(*socket->typeinfo->base_cpp_type);
778 socket->typeinfo->get_base_cpp_value(socket->default_value, value_buffer);
// Computes the value of a linked input socket from the socket it is linked to.
//
// \param from_socket: the origin socket of the link.
// \param to_socket: the input socket whose value is being computed.
//
// If `from_socket` has no value yet, a task is pushed for it. Otherwise its
// value is converted to the type of `to_socket` where needed and stored for
// `to_socket`.
782 void value_task__input__linked(
const SocketInContext &from_socket,
783 const SocketInContext &to_socket)
785 const std::optional<InferenceValue> from_value = all_socket_values_.lookup_try(from_socket);
786 if (!from_value.has_value()) {
787 this->push_value_task(from_socket);
790 const InferenceValue converted_value = this->convert_type_if_necessary(
791 *from_value, *from_socket.socket, *to_socket.socket);
792 all_socket_values_.add_new(to_socket, converted_value);
// Converts an inference value from the base type of one socket to the base
// type of another.
//
// \param src: the value to convert.
// The remaining parameters are declared on lines not visible in this excerpt;
// the call sites pass the source socket and the destination socket.
//
// Visible steps: non-primitive values are handled separately (body not
// visible). The source type is the base C++ type of `from_socket`. If the
// conversion table `conversions` reports the pair as convertible, storage for
// the destination type is allocated from `scope_` and filled by
// `convert_to_uninitialized`.
// NOTE(review): the definitions of `to_type` and `conversions`, the
// same-type shortcut (if any), the non-convertible result and the final return
// are not visible in this excerpt.
795 InferenceValue convert_type_if_necessary(
const InferenceValue &src,
799 if (!src.is_primitive_value()) {
802 const CPPType *from_type = from_socket.
typeinfo->base_cpp_type;
811 if (!conversions.is_convertible(*from_type, *
to_type)) {
814 void *dst = scope_.allocate_owned(*
to_type);
815 conversions.convert_to_uninitialized(*from_type, *
to_type, src.get_primitive_ptr(), dst);
// Tells whether the value of a socket must be treated as unknown because it is
// listed in `top_level_ignored_inputs_`.
//
// \param socket: the socket to check.
// \return for the case that reaches the last line, the flag stored in
//   `top_level_ignored_inputs_` at the position `index_in_all_inputs()` of the
//   socket.
//
// Three early checks come first: no ignored-inputs span was provided, the
// socket has a non-null compute context, and the socket is an output.
// NOTE(review): the values returned by the early checks are not visible in
// this excerpt (presumably false) -- confirm.
819 bool treat_socket_as_unknown(
const SocketInContext &socket)
const
821 if (!top_level_ignored_inputs_.has_value()) {
824 if (socket.context) {
827 if (socket->is_output()) {
830 return (*top_level_ignored_inputs_)[socket->index_in_all_inputs()];
// NOTE(review): the signature of this function is not part of this excerpt.
// Judging by the calls `this->ensure_animation_data_processed(...)` elsewhere
// in the file and by the use of `tree` below, this is presumably the body of
// `ensure_animation_data_processed` -- confirm.
//
// Visible logic:
//  - `tree` is added to `trees_with_handled_animation_data_`; the branch taken
//    when `add` returns false is not visible (presumably an early return so
//    that each tree is processed once).
//  - A function-local static regex is matched against RNA paths. Its first
//    capture group is used as the node name and its second as the input index
//    (parsed with std::stoi). Matches are collected in
//    `animated_inputs_by_node_name`.
//  - RNA paths come from the F-curves of the tree action (for
//    `tree.adt->slot_handle`) and from drivers.
//  - For every node of the tree, each collected input index that is within the
//    range of the node input sockets is resolved to a socket and added to
//    `animated_sockets_`. Out-of-range indices take a branch whose body is not
//    visible.
// NOTE(review): the check that `tree.adt` is non-null, the declaration of
// `match`, and the loop over drivers are not visible in this excerpt.
835 if (!trees_with_handled_animation_data_.add(&
tree)) {
842 static std::regex pattern(R
"#(nodes\["(.*)"\].inputs\[(\d+)\].default_value)#");
843 MultiValueMap<StringRef, int> animated_inputs_by_node_name;
844 auto handle_rna_path = [&](
const char *rna_path) {
846 if (!std::regex_match(rna_path, match, pattern)) {
849 const StringRef node_name{match[1].first, match[1].second - match[1].first};
850 const int socket_index = std::stoi(match[2]);
851 animated_inputs_by_node_name.
add(node_name, socket_index);
855 if (
tree.adt->action) {
857 tree.adt->action->wrap(),
tree.adt->slot_handle, [&](
const FCurve &fcurve) {
858 handle_rna_path(fcurve.rna_path);
863 handle_rna_path(driver->rna_path);
867 if (!animated_inputs_by_node_name.
is_empty()) {
868 for (
const bNode *node :
tree.all_nodes()) {
869 const Span<int> animated_inputs = animated_inputs_by_node_name.
lookup(node->name);
870 const Span<const bNodeSocket *> input_sockets = node->input_sockets();
871 for (
const int socket_index : animated_inputs) {
872 if (socket_index < 0 || socket_index >= input_sockets.
size()) {
877 const bNodeSocket &socket = *input_sockets[socket_index];
878 animated_sockets_.add(&socket);
// Schedules the evaluation of `socket` by pushing it onto `value_tasks_`.
//
// \param socket: the socket whose value is needed before the current task can
//   finish.
884 void push_value_task(
const SocketInContext &socket)
886 value_tasks_.push(socket);
// Helper template that produces an `InferenceValue` from a value of type `T`.
// Within this file it is called with bool, float, int and float3 arguments.
//
// \param value: the value to wrap.
// NOTE(review): the body is not visible in this excerpt, so how the value is
// stored and what owns it cannot be stated here.
889 template<
typename T> InferenceValue make_primitive_inference_value(
const T &value)
895 static const bNodeSocket *get_first_available_bsocket(
const Span<const bNodeSocket *> sockets)
898 if (socket->is_available()) {