Blender V5.0
glsl_preprocess.hh
Go to the documentation of this file.
1/* SPDX-FileCopyrightText: 2024 Blender Authors
2 *
3 * SPDX-License-Identifier: GPL-2.0-or-later */
4
8
9#pragma once
10
11#include <cctype>
12#include <cstdint>
13#include <functional>
14#include <iostream>
15#include <regex>
16#include <sstream>
17#include <string>
18#include <unordered_set>
19#include <vector>
20
21#include "shader_parser.hh"
22
24
/* Expand `token` into the three leading arguments expected by the error reporting callback:
 * the token's line number, its character number and the string of the line it is on.
 * Meant to be used as `report_error(ERROR_TOK(token), "message")`. */
25#define ERROR_TOK(token) (token).line_number(), (token).char_number(), (token).line_str()
26
27/* Metadata extracted from shader source file.
28 * These are then converted to their GPU module equivalent. */
29/* TODO(fclem): Make GPU enums standalone and directly use them instead of using separate enums
30 * and types. */
31namespace metadata {
32
/* Hash a null-terminated string into a 64bit value.
 * Usable inside constant expressions, which allows the result to initialize enumerators. */
constexpr static uint64_t hash(const char *name)
{
  uint64_t result = 2166136261u;
  for (const char *cursor = name; *cursor != '\0'; ++cursor) {
    /* Multiply first, then mix in the current character. */
    result *= 16777619u;
    result ^= *cursor;
  }
  return result;
}
44
/* Runtime overload: hashes the C string held by `name` with the `const char *` overload. */
static uint64_t hash(const std::string &name)
{
  const char *c_string = name.c_str();
  return hash(c_string);
}
49
51 FragCoord = hash("gl_FragCoord"),
52 FragStencilRef = hash("gl_FragStencilRefARB"),
53 FrontFacing = hash("gl_FrontFacing"),
54 GlobalInvocationID = hash("gl_GlobalInvocationID"),
55 InstanceID = hash("gl_InstanceID"),
56 LocalInvocationID = hash("gl_LocalInvocationID"),
57 LocalInvocationIndex = hash("gl_LocalInvocationIndex"),
58 NumWorkGroup = hash("gl_NumWorkGroup"),
59 PointCoord = hash("gl_PointCoord"),
60 PointSize = hash("gl_PointSize"),
61 PrimitiveID = hash("gl_PrimitiveID"),
62 VertexID = hash("gl_VertexID"),
63 WorkGroupID = hash("gl_WorkGroupID"),
64 WorkGroupSize = hash("gl_WorkGroupSize"),
65 drw_debug = hash("drw_debug_"),
66 printf = hash("printf"),
67 assert = hash("assert"),
68 runtime_generated = hash("runtime_generated"),
69};
70
72 in = hash("in"),
73 out = hash("out"),
74 inout = hash("inout"),
75};
76
/* Type identifiers. The value of each enumerator is the `hash()` of the type name spelled in
 * the associated string literal, so a type can be identified by hashing its name.
 * Note that `float1` is the hash of "float" (there is no "float1" spelling). */
77enum Type : uint64_t {
78 float1 = hash("float"),
79 float2 = hash("float2"),
80 float3 = hash("float3"),
81 float4 = hash("float4"),
82 float3x3 = hash("float3x3"),
83 float4x4 = hash("float4x4"),
84 sampler1DArray = hash("sampler1DArray"),
85 sampler2DArray = hash("sampler2DArray"),
86 sampler2D = hash("sampler2D"),
87 sampler3D = hash("sampler3D"),
88 Closure = hash("Closure"),
89};
90
95
97 std::string name;
98 std::vector<ArgumentFormat> arguments;
99};
100
102 uint32_t hash;
103 std::string format;
104};
105
106struct Source {
107 std::vector<Builtin> builtins;
108 /* Note: Could be a set, but for now the order matters. */
109 std::vector<std::string> dependencies;
110 std::vector<PrintfFormat> printf_formats;
111 std::vector<FunctionFormat> functions;
112
113 std::string serialize(const std::string &function_name) const
114 {
115 std::stringstream ss;
116 ss << "static void " << function_name
117 << "(GPUSource &source, GPUFunctionDictionary *g_functions, GPUPrintFormatMap *g_formats) "
118 "{\n";
119 for (auto function : functions) {
120 ss << " {\n";
121 ss << " Vector<metadata::ArgumentFormat> args = {\n";
122 for (auto arg : function.arguments) {
123 ss << " "
124 << "metadata::ArgumentFormat{"
125 << "metadata::Qualifier(" << std::to_string(uint64_t(arg.qualifier)) << "LLU), "
126 << "metadata::Type(" << std::to_string(uint64_t(arg.type)) << "LLU)"
127 << "},\n";
128 }
129 ss << " };\n";
130 ss << " source.add_function(\"" << function.name << "\", args, g_functions);\n";
131 ss << " }\n";
132 }
133 for (auto builtin : builtins) {
134 ss << " source.add_builtin(metadata::Builtin(" << std::to_string(builtin) << "LLU));\n";
135 }
136 for (auto dependency : dependencies) {
137 ss << " source.add_dependency(\"" << dependency << "\");\n";
138 }
139 for (auto format : printf_formats) {
140 ss << " source.add_printf_format(uint32_t(" << std::to_string(format.hash) << "), "
141 << format.format << ", g_formats);\n";
142 }
143 /* Avoid warnings. */
144 ss << " UNUSED_VARS(source, g_functions, g_formats);\n";
145 ss << "}\n";
146 return ss.str();
147 }
148};
149
150} // namespace metadata
151
160 using uint64_t = std::uint64_t;
/* Error reporting callback. Receives the line number, the character number, the content of
 * the offending line and the error message. */
161 using report_callback = std::function<void(
162 int error_line, int error_char, std::string error_line_string, const char *error_str)>;
/* Description of one shared variable, stored as three strings: type, name and array part. */
163 struct SharedVar {
164 std::string type;
165 std::string name;
166 std::string array;
167 };
168
/* NOTE(review): presumably filled by `threadgroup_variables_parsing()` and consumed by
 * `threadgroup_variables_suffix()`; neither is visible in this part of the file. */
169 std::vector<SharedVar> shared_vars_;
170
/* Metadata accumulated by the parsing passes (e.g. dependencies and builtins). It is copied
 * to the caller provided `r_metadata` at the end of `process()`. */
171 metadata::Source metadata;
172
/* Shorthand for the shader parser type used by the mutation passes. */
173 using Parser = shader::parser::Parser;
174
175 public:
181 /* Same as GLSL but enable partial C++ feature support like template, references,
182 * include system, etc ... */
184 };
185
186 static SourceLanguage language_from_filename(const std::string &filename)
187 {
188 if (filename.find(".msl") != std::string::npos) {
189 return MSL;
190 }
191 if (filename.find(".glsl") != std::string::npos) {
192 return GLSL;
193 }
194 if (filename.find(".hh") != std::string::npos) {
195 return CPP;
196 }
197 return UNKNOWN;
198 }
199
200 /* Takes a whole source file and output processed source. */
201 std::string process(SourceLanguage language,
202 std::string str,
203 const std::string &filename,
204 bool do_parse_function,
205 bool do_small_type_linting,
206 report_callback report_error,
207 metadata::Source &r_metadata)
208 {
209 if (language == UNKNOWN) {
210 report_error(0, 0, "", "Unknown file type");
211 return "";
212 }
213 str = remove_comments(str, report_error);
214 threadgroup_variables_parsing(str);
215 if (language == BLENDER_GLSL || language == CPP) {
216 str = disabled_code_mutation(str, report_error);
217 }
218 parse_builtins(str, filename);
219 if (language == BLENDER_GLSL || language == CPP) {
220 if (do_parse_function) {
221 parse_library_functions(str);
222 }
223 if (language == BLENDER_GLSL) {
224 pragma_runtime_generated_parsing(str);
225 pragma_once_linting(str, filename, report_error);
226 }
227 str = include_parse_and_remove(str, report_error);
228 str = pragmas_mutation(str, report_error);
229 str = swizzle_function_mutation(str, report_error);
230 str = enum_macro_injection(str, language == CPP, report_error);
231 if (language == BLENDER_GLSL) {
232 Parser parser(str, report_error);
233 using_mutation(parser, report_error);
234
235 namespace_mutation(parser, report_error);
236 template_struct_mutation(parser, report_error);
237 struct_method_mutation(parser, report_error);
238 empty_struct_mutation(parser, report_error);
239 method_call_mutation(parser, report_error);
240 stage_function_mutation(parser, report_error);
241 resource_guard_mutation(parser, report_error);
242 loop_unroll(parser, report_error);
243 assert_processing(parser, filename, report_error);
244 static_strings_merging(parser, report_error);
245 static_strings_parsing_and_mutation(parser, report_error);
246 str = parser.result_get();
247 str = printf_processing(str, report_error);
248 quote_linting(str, report_error);
249 }
250 {
251 Parser parser(str, report_error);
252 global_scope_constant_linting(parser, report_error);
253 if (do_small_type_linting) {
254 small_type_linting(parser, report_error);
255 }
256 remove_quotes(parser, report_error);
257 argument_reference_mutation(parser, report_error);
258 default_argument_mutation(parser, report_error);
259 str = parser.result_get();
260 }
261 str = variable_reference_mutation(str, report_error);
262 str = template_definition_mutation(str, report_error);
263 if (language == BLENDER_GLSL) {
264 str = namespace_separator_mutation(str);
265 }
266 str = template_call_mutation(str, report_error);
267 }
268 else if (language == MSL) {
269 pragma_runtime_generated_parsing(str);
270 str = include_parse_and_remove(str, report_error);
271 str = pragmas_mutation(str, report_error);
272 }
273#ifdef __APPLE__ /* Limiting to Apple hardware since GLSL compilers might have issues. */
274 if (language == GLSL) {
275 str = matrix_constructor_mutation(str);
276 }
277#endif
278 str = argument_decorator_macro_injection(str);
279 str = array_constructor_macro_injection(str);
280 r_metadata = metadata;
281 return line_directive_prefix(filename) + str + threadgroup_variables_suffix();
282 }
283
284 /* Variant use for python shaders. */
285 std::string process(const std::string &str)
286 {
287 auto no_err_report = [](int, int, std::string, const char *) {};
288 metadata::Source unused;
289 return process(GLSL, str, "", false, false, no_err_report, unused);
290 }
291
292 private:
using regex_callback = std::function<void(const std::smatch &)>;
using regex_callback_with_line_count = std::function<void(const std::smatch &, int64_t)>;

/**
 * Helper to make the code more readable in parsing functions.
 * Invoke `callback` for each non-overlapping match of `regex` inside `str`, from left to right.
 *
 * \param str: String to search. Matches passed to the callback reference its content.
 * \param regex: Expression to search for.
 * \param callback: Called once per match.
 */
void regex_global_search(const std::string &str,
                         const std::regex &regex,
                         regex_callback callback)
{
  std::string::const_iterator cursor = str.begin();
  std::smatch match;
  while (std::regex_search(cursor, str.end(), match, regex)) {
    callback(match);
    cursor = match.suffix().first;
    if (match.length(0) == 0) {
      /* A zero-length match does not consume any character. Restarting the search at the same
       * position would yield the same match forever, so step over one character (which is then
       * not part of any reported prefix), or stop if the end of the string is reached. */
      if (cursor == str.end()) {
        break;
      }
      ++cursor;
    }
  }
}
307
308 void regex_global_search(const std::string &str,
309 const std::regex &regex,
310 regex_callback_with_line_count callback)
311 {
312 using namespace std;
313 int64_t line = 1;
314 regex_global_search(str, regex, [&line, &callback](const std::smatch &match) {
315 line += line_count(match.prefix().str());
316 callback(match, line);
317 line += line_count(match[0].str());
318 });
319 }
320
321 template<typename ReportErrorF>
322 std::string remove_comments(const std::string &str, const ReportErrorF &report_error)
323 {
324 std::string out_str = str;
325 {
326 /* Multi-line comments. */
327 size_t start, end = 0;
328 while ((start = out_str.find("/*", end)) != std::string::npos) {
329 end = out_str.find("*/", start + 2);
330 if (end == std::string::npos) {
331 break;
332 }
333 for (size_t i = start; i < end + 2; ++i) {
334 if (out_str[i] != '\n') {
335 out_str[i] = ' ';
336 }
337 }
338 }
339
340 if (end == std::string::npos) {
341 report_error(line_number(out_str, start),
342 char_number(out_str, start),
343 line_str(out_str, start),
344 "Malformed multi-line comment.");
345 return out_str;
346 }
347 }
348 {
349 /* Single-line comments. */
350 size_t start, end = 0;
351 while ((start = out_str.find("//", end)) != std::string::npos) {
352 end = out_str.find('\n', start + 2);
353 if (end == std::string::npos) {
354 break;
355 }
356 for (size_t i = start; i < end; ++i) {
357 out_str[i] = ' ';
358 }
359 }
360
361 if (end == std::string::npos) {
362 report_error(line_number(out_str, start),
363 char_number(out_str, start),
364 line_str(out_str, start),
365 "Malformed single line comment, missing newline.");
366 return out_str;
367 }
368 }
369 /* Remove trailing white space as they make the subsequent regex much slower. */
370 std::regex regex(R"((\ )*?\n)");
371 return std::regex_replace(out_str, regex, "\n");
372 }
373
374 static std::string template_arguments_mangle(const shader::parser::Scope template_args)
375 {
376 using namespace std;
377 using namespace shader::parser;
378 string args_concat;
379 template_args.foreach_scope(ScopeType::TemplateArg, [&](const Scope &scope) {
380 args_concat += 'T' + scope.start().str();
381 });
382 return args_concat;
383 }
384
385 void template_struct_mutation(Parser &parser, report_callback &report_error)
386 {
387 using namespace std;
388 using namespace shader::parser;
389 {
390 parser.foreach_match("w<..>(..)", [&](const vector<Token> &tokens) {
391 const Scope template_args = tokens[1].scope();
392 template_args.foreach_match("w<..>", [&parser](const vector<Token> &tokens) {
393 parser.replace(tokens[1].scope(), template_arguments_mangle(tokens[1].scope()), true);
394 });
395 });
396 parser.apply_mutations();
397
398 /* Replace full specialization by simple struct. */
399 parser.foreach_match("t<>sw<..>", [&](const std::vector<Token> &tokens) {
400 parser.erase(tokens[0], tokens[2]);
401 parser.replace(tokens[5].scope(), template_arguments_mangle(tokens[5].scope()), true);
402 });
403 parser.apply_mutations();
404 }
405 {
406 parser.foreach_scope(ScopeType::Template, [&](Scope temp) {
407 /* Parse template declaration. */
408 Token struct_start = temp.end().next();
409 if (struct_start != Struct) {
410 return;
411 }
412 Token struct_name = struct_start.next();
413 Scope struct_body = struct_name.next().scope();
414
415 bool error = false;
416 temp.foreach_match("=", [&](const std::vector<Token> &tokens) {
417 report_error(ERROR_TOK(tokens[0]),
418 "Default arguments are not supported inside template declaration");
419 error = true;
420 });
421 if (error) {
422 return;
423 }
424
425 string arg_pattern;
426 vector<string> arg_list;
427 temp.foreach_scope(ScopeType::TemplateArg, [&](Scope arg) {
428 const Token type = arg.start();
429 const Token name = type.next();
430 const string name_str = name.str();
431 const string type_str = type.str();
432
433 arg_list.emplace_back(name_str);
434
435 if (type_str == "typename") {
436 arg_pattern += ",w";
437 }
438 else if (type_str == "enum" || type_str == "bool") {
439 arg_pattern += ",w";
440 }
441 else if (type_str == "int" || type_str == "uint") {
442 arg_pattern += ",0";
443 }
444 else {
445 report_error(ERROR_TOK(type), "Invalid template argument type");
446 }
447 });
448
449 Token struct_end = struct_body.end();
450 const string fn_decl = parser.substr_range_inclusive(struct_start.str_index_start(),
451 struct_end.str_index_last());
452
453 /* Remove declaration. */
454 Token template_keyword = temp.start().prev();
455 parser.erase(template_keyword.str_index_start(), struct_end.line_end());
456
457 /* Replace instantiations. */
458 Scope parent_scope = temp.scope();
459 string specialization_pattern = "tsw<" + arg_pattern.substr(1) + ">";
460 parent_scope.foreach_match(specialization_pattern, [&](const std::vector<Token> &tokens) {
461 if (struct_name.str() != tokens[2].str()) {
462 return;
463 }
464 /* Parse template values. */
465 vector<pair<string, string>> arg_name_value_pairs;
466 for (int i = 0; i < arg_list.size(); i++) {
467 arg_name_value_pairs.emplace_back(arg_list[i], tokens[4 + 2 * i].str());
468 }
469 /* Specialize template content. */
470 Parser instance_parser(fn_decl, report_error, true);
471 instance_parser.foreach_match("w", [&](const std::vector<Token> &tokens) {
472 string token_str = tokens[0].str();
473 for (const auto &arg_name_value : arg_name_value_pairs) {
474 if (token_str == arg_name_value.first) {
475 instance_parser.replace(tokens[0], arg_name_value.second);
476 }
477 }
478 });
479
480 const string template_args = parser.substr_range_inclusive(
481 tokens[3], tokens[3 + arg_pattern.size()]);
482 size_t pos = fn_decl.find(" " + struct_name.str());
483 instance_parser.insert_after(pos + struct_name.str().size(), template_args);
484 /* Paste template content in place of instantiation. */
485 Token end_of_instantiation = tokens.back();
486 string instance = instance_parser.result_get();
487 parser.insert_line_number(tokens.front().str_index_start() - 1,
488 struct_start.line_number());
489 parser.replace(tokens.front().str_index_start(),
490 end_of_instantiation.str_index_last_no_whitespace(),
491 instance);
492 parser.insert_line_number(end_of_instantiation.line_end() + 1,
493 end_of_instantiation.line_number() + 1);
494 });
495 });
496 parser.apply_mutations();
497 }
498 {
499 /* This rely on our codestyle that do not put spaces between template name and the opening
500 * angle bracket. */
501 parser.foreach_match("sw<..>", [&](const std::vector<Token> &tokens) {
502 parser.replace(tokens[2].scope(), template_arguments_mangle(tokens[2].scope()), true);
503 });
504 parser.apply_mutations();
505 }
506 }
507
508 std::string template_definition_mutation(const std::string &str, report_callback &report_error)
509 {
510 if (str.find("template") == std::string::npos) {
511 return str;
512 }
513
514 using namespace std;
515 using namespace shader::parser;
516
517 std::string out_str = str;
518
519 Parser parser(out_str, report_error);
520
521 auto process_specialization = [&](const Token specialization_start,
522 const Scope template_args) {
523 parser.erase(specialization_start, specialization_start.next().next());
524 parser.replace(template_args, template_arguments_mangle(template_args), true);
525 };
526
527 parser.foreach_scope(ScopeType::Global, [&](Scope scope) {
528 /* Replace full specialization by simple functions. */
529 scope.foreach_match("t<>ww<", [&](const std::vector<Token> &tokens) {
530 process_specialization(tokens[0], tokens[5].scope());
531 });
532 scope.foreach_match("t<>ww::w<", [&](const std::vector<Token> &tokens) {
533 process_specialization(tokens[0], tokens[8].scope());
534 });
535 });
536
537 parser.apply_mutations();
538
539 auto process_template = [&](const Token fn_start,
540 const string &fn_name,
541 const Scope fn_args,
542 const Scope temp,
543 const Token fn_end) {
544 bool error = false;
545 temp.foreach_match("=", [&](const std::vector<Token> &tokens) {
546 report_error(tokens[0].line_number(),
547 tokens[0].char_number(),
548 tokens[0].line_str(),
549 "Default arguments are not supported inside template declaration");
550 error = true;
551 });
552 if (error) {
553 return;
554 }
555
556 string arg_pattern;
557 vector<string> arg_list;
558 bool all_template_args_in_function_signature = true;
559 temp.foreach_scope(ScopeType::TemplateArg, [&](Scope arg) {
560 const Token type = arg.start();
561 const Token name = type.next();
562 const string name_str = name.str();
563 const string type_str = type.str();
564
565 arg_list.emplace_back(name_str);
566
567 if (type_str == "typename") {
568 arg_pattern += ",w";
569 bool found = false;
570 /* Search argument list for typenames. If typename matches, the template argument is
571 * present inside the function signature. */
572 fn_args.foreach_match("ww", [&](const std::vector<Token> &tokens) {
573 if (tokens[0].str() == name_str) {
574 found = true;
575 }
576 });
577 all_template_args_in_function_signature &= found;
578 }
579 else if (type_str == "enum" || type_str == "bool") {
580 arg_pattern += ",w";
581 /* Values cannot be resolved using type deduction. */
582 all_template_args_in_function_signature = false;
583 }
584 else if (type_str == "int" || type_str == "uint") {
585 arg_pattern += ",0";
586 /* Values cannot be resolved using type deduction. */
587 all_template_args_in_function_signature = false;
588 }
589 else {
590 report_error(ERROR_TOK(type), "Invalid template argument type");
591 }
592 });
593
594 const string fn_decl = parser.substr_range_inclusive(fn_start.str_index_start(),
595 fn_end.line_end());
596
597 /* Remove declaration. */
598 Token template_keyword = temp.start().prev();
599 parser.erase(template_keyword.str_index_start(), fn_end.line_end());
600
601 auto process_instantiation = [&](const string &inst_name,
602 const Token inst_start,
603 const Token inst_end,
604 const Scope &inst_args) {
605 if (fn_name != inst_name) {
606 return;
607 }
608 /* Parse template values. */
609 vector<pair<string, string>> arg_name_value_pairs;
610 for (int i = 0; i < arg_list.size(); i++) {
611 arg_name_value_pairs.emplace_back(arg_list[i], inst_args[1 + 2 * i].str());
612 }
613 /* Specialize template content. */
614 Parser instance_parser(fn_decl, report_error, true);
615 instance_parser.foreach_token(Word, [&](const Token &word) {
616 string token_str = word.str();
617 for (const auto &arg_name_value : arg_name_value_pairs) {
618 if (token_str == arg_name_value.first) {
619 instance_parser.replace(word, arg_name_value.second);
620 }
621 }
622 });
623
624 if (!all_template_args_in_function_signature) {
625 /* Append template args after function name.
626 * `void func() {}` > `void func<a, 1>() {}`. */
627 size_t pos = fn_decl.find(" " + fn_name);
628 instance_parser.insert_after(pos + fn_name.size(), inst_args.str());
629 }
630 /* Paste template content in place of instantiation. */
631 string instance = instance_parser.result_get();
632 parser.insert_line_number(inst_start.str_index_start() - 1, fn_start.line_number());
633 parser.replace(
634 inst_start.str_index_start(), inst_end.str_index_last_no_whitespace(), instance);
635 parser.insert_line_number(inst_end.line_end() + 1, inst_end.line_number() + 1);
636 };
637
638 /* Replace instantiations. */
639 Scope parent_scope = temp.scope();
640 {
641 string specialization_pattern = "tww<" + arg_pattern.substr(1) + ">(..);";
642 parent_scope.foreach_match(specialization_pattern, [&](const vector<Token> &tokens) {
643 process_instantiation(tokens[2].str(), tokens.front(), tokens.back(), tokens[3].scope());
644 });
645 }
646 {
647 string specialization_pattern = "tww::w<" + arg_pattern.substr(1) + ">(..);";
648 parent_scope.foreach_match(specialization_pattern, [&](const vector<Token> &tokens) {
649 const string inst_name = parser.substr_range_inclusive(tokens[2], tokens[5]);
650 process_instantiation(inst_name, tokens.front(), tokens.back(), tokens[6].scope());
651 });
652 }
653 };
654
655 parser.foreach_match("t<..>ww(..)c?{..}", [&](const vector<Token> &tokens) {
656 process_template(
657 tokens[5], tokens[6].str(), tokens[7].scope(), tokens[1].scope(), tokens[16]);
658 });
659
660 parser.foreach_match("t<..>ww::w(..)c?{..}", [&](const vector<Token> &tokens) {
661 const string fn_name = parser.substr_range_inclusive(tokens[6], tokens[9]);
662 process_template(tokens[5], fn_name, tokens[10].scope(), tokens[1].scope(), tokens[19]);
663 });
664
665 out_str = parser.result_get();
666
667 {
668 /* Check if there is no remaining declaration and instantiation that were not processed. */
669 size_t error_pos;
670 if ((error_pos = out_str.find("template<")) != std::string::npos) {
671 report_error(line_number(out_str, error_pos),
672 char_number(out_str, error_pos),
673 line_str(out_str, error_pos),
674 "Template declaration unsupported syntax");
675 }
676 if ((error_pos = out_str.find("template ")) != std::string::npos) {
677 report_error(line_number(out_str, error_pos),
678 char_number(out_str, error_pos),
679 line_str(out_str, error_pos),
680 "Template instantiation unsupported syntax");
681 }
682 }
683 return out_str;
684 }
685
686 std::string template_call_mutation(const std::string &str, report_callback &report_error)
687 {
688 using namespace std;
689 using namespace shader::parser;
690
691 Parser parser(str, report_error);
692 parser.foreach_match("w<..>", [&](const std::vector<Token> &tokens) {
693 parser.replace(tokens[1].scope(), template_arguments_mangle(tokens[1].scope()), true);
694 });
695 return parser.result_get();
696 }
697
698 /* Remove remaining quotes that can be found in some unsupported C++ macros. */
699 void remove_quotes(Parser &parser, report_callback /*report_error*/)
700 {
701 using namespace std;
702 using namespace shader::parser;
703
704 parser.foreach_token(TokenType::String, [&](const Token token) { parser.erase(token); });
705 parser.apply_mutations();
706 }
707
708 std::string include_parse_and_remove(const std::string &str, report_callback report_error)
709 {
710 using namespace std;
711 using namespace shader::parser;
712
713 Parser parser(str, report_error);
714
715 parser.foreach_match("#w_", [&](const std::vector<Token> &tokens) {
716 if (tokens[1].str() != "include") {
717 return;
718 }
719 string dependency_name = tokens[2].str_exclusive();
720 /* Assert that includes are at the top of the file. */
721 if (dependency_name == "gpu_shader_compat.hh") {
722 /* Skip GLSL-C++ stubs. They are only for IDE linting. */
723 parser.erase(tokens.front(), tokens.back());
724 return;
725 }
726 if (dependency_name.find("infos.hh") != std::string::npos) {
727 /* Skip info files. They are only for IDE linting. */
728 parser.erase(tokens.front(), tokens.back());
729 return;
730 }
731 if (dependency_name.find("gpu_shader_create_info.hh") != std::string::npos) {
732 /* Skip info files. They are only for IDE linting. */
733 parser.erase(tokens.front(), tokens.back());
734 return;
735 }
736 metadata.dependencies.emplace_back(dependency_name);
737 parser.erase(tokens.front(), tokens.back());
738 });
739
740 return parser.result_get();
741 }
742
743 void pragma_runtime_generated_parsing(const std::string &str)
744 {
745 if (str.find("\n#pragma runtime_generated") != std::string::npos) {
746 metadata.builtins.emplace_back(metadata::Builtin::runtime_generated);
747 }
748 }
749
750 void pragma_once_linting(const std::string &str,
751 const std::string &filename,
752 report_callback report_error)
753 {
754 if (filename.find("_lib.") == std::string::npos && filename.find(".hh") == std::string::npos) {
755 return;
756 }
757 if (str.find("\n#pragma once") == std::string::npos) {
758 report_error(0, 0, "", "Header files must contain #pragma once directive.");
759 }
760 }
761
762 void loop_unroll(Parser &parser, report_callback report_error)
763 {
764 using namespace std;
765 using namespace shader::parser;
766
767 auto parse_for_args =
768 [&](const Scope loop_args, Scope &r_init, Scope &r_condition, Scope &r_iter) {
769 r_init = r_condition = r_iter = Scope::invalid();
770 loop_args.foreach_scope(ScopeType::LoopArg, [&](const Scope arg) {
771 if (arg.start().prev() == '(' && arg.end().next() == ';') {
772 r_init = arg;
773 }
774 else if (arg.start().prev() == ';' && arg.end().next() == ';') {
775 r_condition = arg;
776 }
777 else if (arg.start().prev() == ';' && arg.end().next() == ')') {
778 r_iter = arg;
779 }
780 else {
781 report_error(ERROR_TOK(arg.start()), "Invalid loop declaration.");
782 }
783 });
784 };
785
786 auto process_loop = [&](const Token loop_start,
787 const int iter_count,
788 const int iter_init,
789 const int iter_incr,
790 const bool condition_is_trivial,
791 const bool iteration_is_trivial,
792 const Scope init,
793 const Scope cond,
794 const Scope iter,
795 const Scope body,
796 const string body_prefix = "",
797 const string body_suffix = "") {
798 /* Check that there is no unsupported keywords in the loop body. */
799 bool error = false;
800 /* Checks if `continue` exists, even in switch statement inside the unrolled loop. */
801 body.foreach_token(Continue, [&](const Token token) {
802 if (token.first_containing_scope_of_type(ScopeType::LoopBody) == body) {
803 report_error(ERROR_TOK(token), "Unrolled loop cannot contain \"continue\" statement.");
804 error = true;
805 }
806 });
807 /* Checks if `break` exists directly the unrolled loop scope. Switch statements are ok. */
808 body.foreach_token(Break, [&](const Token token) {
809 if (token.first_containing_scope_of_type(ScopeType::LoopBody) == body) {
810 const Scope switch_scope = token.first_containing_scope_of_type(ScopeType::SwitchBody);
811 if (switch_scope.is_invalid() || !body.contains(switch_scope)) {
812 report_error(ERROR_TOK(token), "Unrolled loop cannot contain \"break\" statement.");
813 error = true;
814 }
815 }
816 });
817 if (error) {
818 return;
819 }
820
821 if (!parser.replace_try(loop_start, body.end(), "", true)) {
822 /* This is the case of nested loops. This loop will be processed in another parser pass. */
823 return;
824 }
825
826 string indent_init, indent_cond, indent_iter;
827 if (init.is_valid()) {
828 indent_init = string(init.start().char_number() - 1, ' ');
829 }
830 if (cond.is_valid()) {
831 indent_cond = string(cond.start().char_number() - 3, ' ');
832 }
833 if (iter.is_valid()) {
834 indent_iter = string(iter.start().char_number(), ' ');
835 }
836 string indent_body = string(body.start().char_number(), ' ');
837 string indent_end = string(body.end().char_number(), ' ');
838
839 /* If possible, replaces the index of the loop iteration inside the given string. */
840 auto replace_index = [&](const string &str, int loop_index) {
841 if (iter.is_invalid() || !iteration_is_trivial || str.empty()) {
842 return str;
843 }
844 Parser str_parser(str, report_error);
845 str_parser.foreach_token(Word, [&](const Token tok) {
846 if (tok.str() == iter[0].str()) {
847 str_parser.replace(tok, std::to_string(loop_index), true);
848 }
849 });
850 return str_parser.result_get();
851 };
852
853 parser.insert_after(body.end(), "\n");
854 if (init.is_valid() && !iteration_is_trivial) {
855 parser.insert_line_number(body.end(), init.start().line_number());
856 parser.insert_after(body.end(), indent_init + "{" + init.str() + ";\n");
857 }
858 else {
859 parser.insert_after(body.end(), "{\n");
860 }
861 for (int64_t i = 0, value = iter_init; i < iter_count; i++, value += iter_incr) {
862 if (cond.is_valid() && !condition_is_trivial) {
863 parser.insert_line_number(body.end(), cond.start().line_number());
864 parser.insert_after(body.end(), indent_cond + "if(" + cond.str() + ")\n");
865 }
866 parser.insert_after(body.end(), replace_index(body_prefix, value));
867 parser.insert_line_number(body.end(), body.start().line_number());
868 parser.insert_after(body.end(), indent_body + replace_index(body.str(), value) + "\n");
869 parser.insert_after(body.end(), body_suffix);
870 if (iter.is_valid() && !iteration_is_trivial) {
871 parser.insert_line_number(body.end(), iter.start().line_number());
872 parser.insert_after(body.end(), indent_iter + iter.str() + ";\n");
873 }
874 }
875 parser.insert_line_number(body.end(), body.end().line_number());
876 parser.insert_after(body.end(), indent_end + body.end().str_with_whitespace());
877 };
878
879 do {
880 /* [[gpu::unroll]]. */
881 parser.foreach_match("[[w::w]]f(..){..}", [&](const std::vector<Token> tokens) {
882 if (tokens[1].scope().str() != "[gpu::unroll]") {
883 return;
884 }
885 const Token for_tok = tokens[8];
886 const Scope loop_args = tokens[9].scope();
887 const Scope loop_body = tokens[13].scope();
888
889 Scope init, cond, iter;
890 parse_for_args(loop_args, init, cond, iter);
891
892 /* Init statement. */
893 const Token var_type = init[0];
894 const Token var_name = init[1];
895 const Token var_init = init[2];
896 if (var_type.str() != "int" && var_type.str() != "uint") {
897 report_error(ERROR_TOK(var_init), "Can only unroll integer based loop.");
898 return;
899 }
900 if (var_init != '=') {
901 report_error(ERROR_TOK(var_init), "Expecting assignment here.");
902 return;
903 }
904 if (init[3] != '0' && init[3] != '-') {
905 report_error(ERROR_TOK(init[3]), "Expecting integer literal here.");
906 return;
907 }
908
909 /* Conditional statement. */
910 const Token cond_var = cond[0];
911 const Token cond_type = cond[1];
912 const Token cond_sign = (cond[2] == '+' || cond[2] == '-') ? cond[2] : Token::invalid();
913 const Token cond_end = cond_sign.is_valid() ? cond[3] : cond[2];
914 if (cond_var.str() != var_name.str()) {
915 report_error(ERROR_TOK(cond_var), "Non matching loop counter variable.");
916 return;
917 }
918 if (cond_end != '0') {
919 report_error(ERROR_TOK(cond_end), "Expecting integer literal here.");
920 return;
921 }
922
923 /* Iteration statement. */
924 const Token iter_var = iter[0];
925 const Token iter_type = iter[1];
926 const Token iter_end = iter[1];
927 int iter_incr = 0;
928 if (iter_var.str() != var_name.str()) {
929 report_error(ERROR_TOK(iter_var), "Non matching loop counter variable.");
930 return;
931 }
932 if (iter_type == Increment) {
933 iter_incr = +1;
934 if (cond_type == '>') {
935 report_error(ERROR_TOK(for_tok), "Unsupported condition in unrolled loop.");
936 return;
937 }
938 }
939 else if (iter_type == Decrement) {
940 iter_incr = -1;
941 if (cond_type == '<') {
942 report_error(ERROR_TOK(for_tok), "Unsupported condition in unrolled loop.");
943 return;
944 }
945 }
946 else {
947 report_error(ERROR_TOK(iter_type), "Unsupported loop expression. Expecting ++ or --.");
948 return;
949 }
950
951 int64_t init_value = std::stol(
952 parser.substr_range_inclusive(var_init.next(), var_init.scope().end()));
953 int64_t end_value = std::stol(
954 parser.substr_range_inclusive(cond_sign.is_valid() ? cond_sign : cond_end, cond_end));
955 /* TODO(fclem): Support arbitrary strides (aka, arbitrary iter statement). */
956 int iter_count = std::abs(end_value - init_value);
957 if (cond_type == GEqual || cond_type == LEqual) {
958 iter_count += 1;
959 }
960
961 bool condition_is_trivial = (cond_end == cond.end());
962 bool iteration_is_trivial = (iter_end == iter.end());
963
964 process_loop(tokens[0],
965 iter_count,
966 init_value,
967 iter_incr,
968 condition_is_trivial,
969 iteration_is_trivial,
970 init,
971 cond,
972 iter,
973 loop_body);
974 });
975
976 /* [[gpu::unroll(n)]]. */
977 parser.foreach_match("[[w::w(0)]]f(..){..}", [&](const std::vector<Token> tokens) {
978 if (tokens[5].str() != "unroll") {
979 return;
980 }
981 const Scope loop_args = tokens[12].scope();
982 const Scope loop_body = tokens[16].scope();
983
984 Scope init, cond, iter;
985 parse_for_args(loop_args, init, cond, iter);
986
987 int iter_count = std::stol(tokens[7].str());
988
989 process_loop(tokens[0], iter_count, 0, 0, false, false, init, cond, iter, loop_body);
990 });
991
992 /* [[gpu::unroll_define(max_n)]]. */
993 parser.foreach_match("[[w::w(0)]]f(..){..}", [&](const std::vector<Token> tokens) {
994 if (tokens[5].str() != "unroll_define") {
995 return;
996 }
997 const Scope loop_args = tokens[12].scope();
998 const Scope loop_body = tokens[16].scope();
999
1000 /* Validate format. */
1001 Token define_name = Token::invalid();
1002 Token iter_var = Token::invalid();
1003 loop_args.foreach_match("ww=0;w<w;wP", [&](const std::vector<Token> tokens) {
1004 if (tokens[1].str() != tokens[5].str() || tokens[5].str() != tokens[9].str()) {
1005 return;
1006 }
1007 iter_var = tokens[1];
1008 define_name = tokens[7];
1009 });
1010
1011 if (define_name.is_invalid()) {
1012 report_error(ERROR_TOK(loop_args.start()),
1013 "Incompatible loop format for [[gpu::unroll_define(max_n)]], expected "
1014 "'(int i = 0; i < DEFINE; i++)'");
1015 return;
1016 }
1017
1018 Scope init, cond, iter;
1019 parse_for_args(loop_args, init, cond, iter);
1020
1021 int iter_count = std::stol(tokens[7].str());
1022
1023 string body_prefix = "#if " + define_name.str() + " > " + iter_var.str() + "\n";
1024
1025 process_loop(tokens[0],
1026 iter_count,
1027 0,
1028 1,
1029 true,
1030 true,
1031 init,
1032 cond,
1033 iter,
1034 loop_body,
1035 body_prefix,
1036 "#endif\n");
1037 });
1038 } while (parser.apply_mutations());
1039
1040 /* Check for remaining keywords. */
1041 parser.foreach_match("[[w::w", [&](const std::vector<Token> tokens) {
1042 if (tokens[2].str() == "gpu" && tokens[5].str() == "unroll") {
1043 report_error(ERROR_TOK(tokens[0]), "Incompatible loop format for [[gpu::unroll]].");
1044 }
1045 });
1046 }
1047
1048 void namespace_mutation(Parser &parser, report_callback report_error)
1049 {
1050 using namespace std;
1051 using namespace shader::parser;
1052
1053 /* Parse each namespace declaration. */
1054 parser.foreach_scope(ScopeType::Namespace, [&](const Scope &scope) {
1055 /* TODO(fclem): This could be supported using multiple passes. */
1056 scope.foreach_match("n", [&](const std::vector<Token> &tokens) {
1057 report_error(ERROR_TOK(tokens[0]), "Nested namespaces are unsupported.");
1058 });
1059
1060 string namespace_prefix = namespace_separator_mutation(
1061 scope.start().prev().full_symbol_name() + "::");
1062 auto process_symbol = [&](const Token &symbol) {
1063 if (symbol.next() == '<') {
1064 /* Template instantiation or specialization. */
1065 return;
1066 }
1067 /* Replace all occurrences of the non-namespace specified symbol. */
1068 scope.foreach_token(Word, [&](const Token &token) {
1069 if (token.str() != symbol.str()) {
1070 return;
1071 }
1072 /* Reject symbols that already have namespace specified. */
1073 if (token.namespace_start() != token) {
1074 return;
1075 }
1076 /* Reject method calls. */
1077 if (token.prev() == '.') {
1078 return;
1079 }
1080 parser.replace(token, namespace_prefix + token.str(), true);
1081 });
1082 };
1083
1084 unordered_set<string> processed_functions;
1085
1086 scope.foreach_function([&](bool, Token, Token fn_name, Scope, bool, Scope) {
1087 /* Note: Struct scopes are currently parsed as Local. */
1088 if (fn_name.scope().type() == ScopeType::Local) {
1089 /* Don't process functions inside a struct scope as the namespace must not be apply
1090 * to them, but to the type. Otherwise, method calls will not work. */
1091 return;
1092 }
1093 if (processed_functions.count(fn_name.str())) {
1094 /* Don't process function names twice. Can happen with overloads. */
1095 return;
1096 }
1097 processed_functions.emplace(fn_name.str());
1098 process_symbol(fn_name);
1099 });
1100 scope.foreach_struct([&](Token, Token struct_name, Scope) { process_symbol(struct_name); });
1101
1102 Token namespace_tok = scope.start().prev().namespace_start().prev();
1103 if (namespace_tok == Namespace) {
1104 parser.erase(namespace_tok, scope.start());
1105 parser.erase(scope.end());
1106 }
1107 else {
1108 report_error(ERROR_TOK(namespace_tok), "Expected namespace token.");
1109 }
1110 });
1111
1112 parser.apply_mutations();
1113 }
1114
1115 /* Needs to run before namespace mutation so that `using` have more precedence. */
1116 void using_mutation(Parser &parser, report_callback report_error)
1117 {
1118 using namespace std;
1119 using namespace shader::parser;
1120
1121 parser.foreach_match("un", [&](const std::vector<Token> &tokens) {
1122 report_error(ERROR_TOK(tokens[0]),
1123 "Unsupported `using namespace`. "
1124 "Add individual `using` directives for each needed symbol.");
1125 });
1126
1127 auto process_using = [&](const Token &using_tok,
1128 const Token &from,
1129 const Token &to_start,
1130 const Token &to_end,
1131 const Token &end_tok) {
1132 string to = parser.substr_range_inclusive(to_start, to_end);
1133 string namespace_prefix = parser.substr_range_inclusive(to_start,
1134 to_end.prev().prev().prev());
1135 Scope scope = from.scope();
1136
1137 /* Using the keyword in global or at namespace scope. */
1138 if (scope.type() == ScopeType::Global) {
1139 report_error(ERROR_TOK(using_tok), "The `using` keyword is not allowed in global scope.");
1140 return;
1141 }
1142 if (scope.type() == ScopeType::Namespace) {
1143 /* Ensure we are bringing symbols from the same namespace.
1144 * Otherwise we can have different shadowing outcome between shader and C++. */
1145 string namespace_name = scope.start().prev().full_symbol_name();
1146 if (namespace_name != namespace_prefix) {
1147 report_error(
1148 ERROR_TOK(using_tok),
1149 "The `using` keyword is only allowed in namespace scope to make visible symbols "
1150 "from the same namespace declared in another scope, potentially from another "
1151 "file.");
1152 return;
1153 }
1154 }
1155
1156 to = namespace_separator_mutation(to);
1157
1158 /* Assignments do not allow to alias functions symbols. */
1159 const bool use_alias = from.str() != to_end.str();
1160 const bool replace_fn = !use_alias;
1164
1165 /* Replace all occurrences of the non-namespace specified symbol. */
1166 scope.foreach_token(Word, [&](const Token &token) {
1167 /* Do not replace symbols before the using statement. */
1168 if (token.index <= to_end.index) {
1169 return;
1170 }
1171 /* Reject symbols that contain the target symbol name. */
1172 if (token.prev() == ':') {
1173 return;
1174 }
1175 if (!replace_fn && token.next() == '(') {
1176 return;
1177 }
1178 if (token.str() != from.str()) {
1179 return;
1180 }
1181 parser.replace(token, to, true);
1182 });
1183
1184 parser.erase(using_tok, end_tok);
1185 };
1186
1187 parser.foreach_match("uw::w", [&](const std::vector<Token> &tokens) {
1188 Token end = tokens.back().find_next(SemiColon);
1189 process_using(tokens[0], end.prev(), tokens[1], end.prev(), end);
1190 });
1191
1192 parser.foreach_match("uw=w::w", [&](const std::vector<Token> &tokens) {
1193 Token end = tokens.back().find_next(SemiColon);
1194 process_using(tokens[0], tokens[1], tokens[3], end.prev(), end);
1195 });
1196
1197 parser.apply_mutations();
1198
1199 /* Verify all using were processed. */
1200 parser.foreach_token(Using, [&](const Token &token) {
1201 report_error(ERROR_TOK(token), "Unsupported `using` keyword usage.");
1202 });
1203 }
1204
1205 std::string namespace_separator_mutation(const std::string &str)
1206 {
1207 std::string out = str;
1208
1209 /* Global namespace reference. */
1210 replace_all(out, " ::", " ");
1211 /* Specific namespace reference.
1212 * Cannot use `__` because of some compilers complaining about reserved symbols. */
1213 replace_all(out, "::", "_");
1214 return out;
1215 }
1216
1217 std::string disabled_code_mutation(const std::string &str, report_callback &report_error)
1218 {
1219 using namespace std;
1220 using namespace shader::parser;
1221
1222 Parser parser(str, report_error);
1223
1224 auto process_disabled_scope = [&](Token start_tok) {
1225 /* Search for endif with the same indentation. Assume formatted input. */
1226 string end_str = start_tok.str_with_whitespace() + "endif";
1227 size_t scope_end = parser.data_get().str.find(end_str, start_tok.str_index_start());
1228 if (scope_end == string::npos) {
1229 report_error(ERROR_TOK(start_tok), "Couldn't find end of disabled scope.");
1230 return;
1231 }
1232 /* Search for else/elif with the same indentation. Assume formatted input. */
1233 string else_str = start_tok.str_with_whitespace() + "el";
1234 size_t scope_else = parser.data_get().str.find(else_str, start_tok.str_index_start());
1235 if (scope_else != string::npos && scope_else < scope_end) {
1236 /* Only erase the content and keep the preprocessor directives. */
1237 parser.erase(start_tok.line_end() + 1, scope_else - 1);
1238 }
1239 else {
1240 /* Erase the content and the preprocessor directives. */
1241 parser.erase(start_tok.str_index_start(), scope_end + end_str.size());
1242 }
1243 };
1244
1245 parser.foreach_match("#ww", [&](const std::vector<Token> &tokens) {
1246 if (tokens[1].str() == "ifndef" && tokens[2].str() == "GPU_SHADER") {
1247 process_disabled_scope(tokens[0]);
1248 }
1249 });
1250 parser.foreach_match("#i!w(w)", [&](const std::vector<Token> &tokens) {
1251 if (tokens[1].str() == "if" && tokens[3].str() == "defined" &&
1252 tokens[5].str() == "GPU_SHADER")
1253 {
1254 process_disabled_scope(tokens[0]);
1255 }
1256 });
1257 parser.foreach_match("#i0", [&](const std::vector<Token> &tokens) {
1258 if (tokens[1].str() == "if" && tokens[2].str() == "0") {
1259 process_disabled_scope(tokens[0]);
1260 }
1261 });
1262 return parser.result_get();
1263 }
1264
1265 std::string pragmas_mutation(const std::string &str, report_callback &report_error)
1266 {
1267 /* Remove unsupported directives. */
1268 using namespace std;
1269 using namespace shader::parser;
1270
1271 Parser parser(str, report_error);
1272 parser.foreach_match("#ww", [&](const std::vector<Token> &tokens) {
1273 if (tokens[1].str() == "pragma") {
1274 if (tokens[2].str() == "once") {
1275 parser.erase(tokens.front(), tokens.back());
1276 }
1277 else if (tokens[2].str() == "runtime_generated") {
1278 parser.erase(tokens.front(), tokens.back());
1279 }
1280 }
1281 });
1282 return parser.result_get();
1283 }
1284
1285 std::string swizzle_function_mutation(const std::string &str, report_callback &report_error)
1286 {
1287 using namespace std;
1288 using namespace shader::parser;
1289
1290 Parser parser(str, report_error);
1291
1292 parser.foreach_scope(ScopeType::Global, [&](Scope scope) {
1293 /* Change C++ swizzle functions into plain swizzle. */
1295 scope.foreach_match(".w()", [&](const std::vector<Token> &tokens) {
1296 string method_name = tokens[1].str();
1297 if (method_name.length() > 1 && method_name.length() <= 4 &&
1298 (method_name.find_first_not_of("xyzw") == string::npos ||
1299 method_name.find_first_not_of("rgba") == string::npos))
1300 {
1301 /* `.xyz()` -> `.xyz` */
1302 /* Keep character count the same. Replace parenthesis by spaces. */
1303 parser.replace(tokens[2], tokens[3], " ");
1304 }
1305 });
1306 });
1307 return parser.result_get();
1308 }
1309
1310 void threadgroup_variables_parsing(const std::string &str)
1311 {
1312 std::regex regex(R"(shared\s+(\w+)\s+(\w+)([^;]*);)");
1313 regex_global_search(str, regex, [&](const std::smatch &match) {
1314 shared_vars_.push_back({match[1].str(), match[2].str(), match[3].str()});
1315 });
1316 }
1317
1318 void parse_library_functions(const std::string &str)
1319 {
1320 using namespace metadata;
1321 std::regex regex_func(R"(void\s+(\w+)\s*\‍(([^)]+\))\s*\{)");
1322 regex_global_search(str, regex_func, [&](const std::smatch &match) {
1323 std::string name = match[1].str();
1324 std::string args = match[2].str();
1325
1326 FunctionFormat fn;
1327 fn.name = name;
1328
1329 std::regex regex_arg(R"((?:(const|in|out|inout)\s)?(\w+)\s([\w\‍[\‍]]+)(?:,|\)))");
1330 regex_global_search(args, regex_arg, [&](const std::smatch &arg) {
1331 std::string qualifier = arg[1].str();
1332 std::string type = arg[2].str();
1333 if (qualifier.empty() || qualifier == "const") {
1334 qualifier = "in";
1335 }
1336 fn.arguments.emplace_back(
1337 ArgumentFormat{metadata::Qualifier(hash(qualifier)), metadata::Type(hash(type))});
1338 });
1339 metadata.functions.emplace_back(fn);
1340 });
1341 }
1342
1343 void parse_builtins(const std::string &str, const std::string &filename)
1344 {
1345 const bool skip_drw_debug = filename.find("draw_debug_draw_lib.glsl") != std::string::npos ||
1346 filename.find("draw_debug_draw_display_vert.glsl") !=
1347 std::string::npos ||
1348 filename.find("draw_shader_shared.hh") != std::string::npos;
1349 using namespace metadata;
1350 /* TODO: This can trigger false positive caused by disabled #if blocks. */
1351 std::string tokens[] = {"gl_FragCoord",
1352 "gl_FragStencilRefARB",
1353 "gl_FrontFacing",
1354 "gl_GlobalInvocationID",
1355 "gl_InstanceID",
1356 "gl_LocalInvocationID",
1357 "gl_LocalInvocationIndex",
1358 "gl_NumWorkGroup",
1359 "gl_PointCoord",
1360 "gl_PointSize",
1361 "gl_PrimitiveID",
1362 "gl_VertexID",
1363 "gl_WorkGroupID",
1364 "gl_WorkGroupSize",
1365 "drw_debug_",
1366#ifdef WITH_GPU_SHADER_ASSERT
1367 "assert",
1368#endif
1369 "printf"};
1370 for (auto &token : tokens) {
1371 if (skip_drw_debug && token == "drw_debug_") {
1372 continue;
1373 }
1374 if (str.find(token) != std::string::npos) {
1375 metadata.builtins.emplace_back(Builtin(hash(token)));
1376 }
1377 }
1378 }
1379
1380 template<typename ReportErrorF>
1381 std::string printf_processing(const std::string &str, const ReportErrorF &report_error)
1382 {
1383 std::string out_str = str;
1384 {
1385 /* Example: `printf(2, b, f(c, d));` > `printf(2@ b@ f(c@ d))$` */
1386 size_t start, end = 0;
1387 while ((start = out_str.find("printf(", end)) != std::string::npos) {
1388 end = out_str.find(';', start);
1389 if (end == std::string::npos) {
1390 break;
1391 }
1392 out_str[end] = '$';
1393 int bracket_depth = 0;
1394 int arg_len = 0;
1395 for (size_t i = start; i < end; ++i) {
1396 if (out_str[i] == '(') {
1397 bracket_depth++;
1398 }
1399 else if (out_str[i] == ')') {
1400 bracket_depth--;
1401 }
1402 else if (bracket_depth == 1 && out_str[i] == ',') {
1403 out_str[i] = '@';
1404 arg_len++;
1405 }
1406 }
1407 if (arg_len > 99) {
1408 report_error(line_number(out_str, start),
1409 char_number(out_str, start),
1410 line_str(out_str, start),
1411 "Too many parameters in printf. Max is 99.");
1412 break;
1413 }
1414 /* Encode number of arg in the `ntf` of `printf`. */
1415 out_str[start + sizeof("printf") - 4] = '$';
1416 out_str[start + sizeof("printf") - 3] = ((arg_len / 10) > 0) ? ('0' + arg_len / 10) : '$';
1417 out_str[start + sizeof("printf") - 2] = '0' + arg_len % 10;
1418 }
1419 if (end == 0) {
1420 /* No printf in source. */
1421 return str;
1422 }
1423 }
1424 /* Example: `pri$$1(2@ b)$` > `{int c_ = print_header(1, 2); c_ = print_data(c_, b); }` */
1425 {
1426 std::regex regex(R"(pri\$\$?(\d{1,2})\‍()");
1427 out_str = std::regex_replace(out_str, regex, "{uint c_ = print_header($1u, ");
1428 }
1429 {
1430 std::regex regex(R"(\@)");
1431 out_str = std::regex_replace(out_str, regex, "); c_ = print_data(c_,");
1432 }
1433 {
1434 std::regex regex(R"(\$)");
1435 out_str = std::regex_replace(out_str, regex, "; }");
1436 }
1437 return out_str;
1438 }
1439
1440 void assert_processing(Parser &parser, const std::string &filepath, report_callback report_error)
1441 {
1442 std::string filename = std::regex_replace(filepath, std::regex(R"((?:.*)\/(.*))"), "$1");
1443
1444 using namespace std;
1445 using namespace shader::parser;
1446
1447 /* Example: `assert(i < 0)` > `if (!(i < 0)) { printf(...); }` */
1448 parser.foreach_match("w(..)", [&](const vector<Token> &tokens) {
1449 if (tokens[0].str() != "assert") {
1450 return;
1451 }
1452 string replacement;
1453#ifdef WITH_GPU_SHADER_ASSERT
1454 string condition = tokens[1].scope().str();
1455 replacement += "if (!" + condition + ") ";
1456 replacement += "{";
1457 replacement += " printf(\"";
1458 replacement += "Assertion failed: " + condition + ", ";
1459 replacement += "file " + filename + ", ";
1460 replacement += "line %d, ";
1461 replacement += "thread (%u,%u,%u).\\n";
1462 replacement += "\"";
1463 replacement += ", __LINE__, GPU_THREAD.x, GPU_THREAD.y, GPU_THREAD.z); ";
1464 replacement += "}";
1465#endif
1466 parser.replace(tokens[0], tokens[4], replacement);
1467 });
1468#ifndef WITH_GPU_SHADER_ASSERT
1469 (void)filename;
1470 (void)report_error;
1471#endif
1472 parser.apply_mutations();
1473 }
1474
1475 /* String hash are outputted inside GLSL and needs to fit 32 bits. */
1476 static uint32_t hash_string(const std::string &str)
1477 {
1478 uint64_t hash_64 = metadata::hash(str);
1479 uint32_t hash_32 = uint32_t(hash_64 ^ (hash_64 >> 32));
1480 return hash_32;
1481 }
1482
1483 void static_strings_merging(Parser &parser, report_callback /*report_error*/)
1484 {
1485 using namespace std;
1486 using namespace shader::parser;
1487
1488 do {
1489 parser.foreach_match("__", [&](const std::vector<Token> &tokens) {
1490 string first = tokens[0].str();
1491 string second = tokens[1].str();
1492 string between = parser.substr_range_inclusive(
1493 tokens[0].str_index_last_no_whitespace() + 1, tokens[1].str_index_start() - 1);
1494 string trailing = parser.substr_range_inclusive(
1495 tokens[1].str_index_last_no_whitespace() + 1, tokens[1].str_index_last());
1496 string merged = first.substr(0, first.length() - 1) + second.substr(1) + between +
1497 trailing;
1498 parser.replace_try(tokens[0], tokens[1], merged);
1499 });
1500 } while (parser.apply_mutations());
1501 }
1502
1503 void static_strings_parsing_and_mutation(Parser &parser, report_callback /*report_error*/)
1504 {
1505 using namespace std;
1506 using namespace shader::parser;
1507
1508 parser.foreach_token(String, [&](const Token &token) {
1509 uint32_t hash = hash_string(token.str());
1510 metadata::PrintfFormat format = {hash, token.str()};
1511 metadata.printf_formats.emplace_back(format);
1512 parser.replace(token, std::to_string(hash) + 'u', true);
1513 });
1514 parser.apply_mutations();
1515 }
1516
1517 /* Move all method definition outside of struct definition blocks. */
1518 void struct_method_mutation(Parser &parser, report_callback report_error)
1519 {
1520 using namespace std;
1521 using namespace shader::parser;
1522
1523 parser.foreach_scope(ScopeType::Global, [&](Scope scope) {
1524 /* `class` -> `struct` */
1525 scope.foreach_match("S", [&](const std::vector<Token> &tokens) {
1526 parser.replace(tokens[0], tokens[0], "struct ");
1527 });
1528 });
1529
1530 parser.apply_mutations();
1531
1532 parser.foreach_scope(ScopeType::Global, [&](Scope scope) {
1533 scope.foreach_match("sw", [&](const std::vector<Token> &tokens) {
1534 const Token struct_name = tokens[1];
1535
1536 if (struct_name.next() == ':') {
1537 report_error(struct_name.next().line_number(),
1538 struct_name.next().char_number(),
1539 struct_name.next().line_str(),
1540 "class inheritance is not supported");
1541 return;
1542 }
1543 if (struct_name.next() != '{') {
1544 report_error(struct_name.line_number(),
1545 struct_name.char_number(),
1546 struct_name.line_str(),
1547 "Expected `{`");
1548 return;
1549 }
1550
1551 const Scope struct_scope = struct_name.next().scope();
1552 const Token struct_end = struct_scope.end().next();
1553
1554 /* Erase `public:` and `private:` keywords. */
1555 struct_scope.foreach_match("v:", [&](const std::vector<Token> &tokens) {
1556 parser.erase(tokens[0].line_start(), tokens[1].line_end());
1557 });
1558 struct_scope.foreach_match("V:", [&](const std::vector<Token> &tokens) {
1559 parser.erase(tokens[0].line_start(), tokens[1].line_end());
1560 });
1561
1562 struct_scope.foreach_match("ww(", [&](const std::vector<Token> &tokens) {
1563 if (tokens[0].prev() == Const) {
1564 report_error(tokens[0].prev().line_number(),
1565 tokens[0].prev().char_number(),
1566 tokens[0].prev().line_str(),
1567 "function return type is marked `const` but it makes no sense for values "
1568 "and returning reference is not supported");
1569 return;
1570 }
1571
1572 const bool is_static = tokens[0].prev() == Static;
1573 const Token fn_start = is_static ? tokens[0].prev() : tokens[0];
1574 const Scope fn_args = tokens[2].scope();
1575 const Token after_args = fn_args.end().next();
1576 const bool is_const = after_args == Const;
1577 const Scope fn_body = (is_const ? after_args.next() : after_args).scope();
1578
1579 string fn_content = parser.substr_range_inclusive(fn_start.line_start(),
1580 fn_body.end().line_end() + 1);
1581
1582 Parser fn_parser(fn_content, report_error);
1583 fn_parser.foreach_scope(ScopeType::Global, [&](Scope scope) {
1584 if (is_static) {
1585 scope.foreach_match("mww(", [&](const std::vector<Token> &tokens) {
1586 const Token fn_name = tokens[2];
1587 fn_parser.replace(fn_name, fn_name, struct_name.str() + "::" + fn_name.str());
1588 /* WORKAROUND: Erase the static keyword as it conflict with the wrapper class
1589 * member accesses MSL. */
1590 fn_parser.erase(tokens[0]);
1591 });
1592 }
1593 else {
1594 scope.foreach_match("ww(", [&](const std::vector<Token> &tokens) {
1595 const Scope args = tokens[2].scope();
1596 const bool has_no_args = args.token_count() == 2;
1597 const char *suffix = (has_no_args ? "" : ", ");
1598
1599 if (is_const) {
1600 fn_parser.erase(args.end().next());
1601 fn_parser.insert_after(args.start(),
1602 "const " + struct_name.str() + " this_" + suffix);
1603 }
1604 else {
1605 fn_parser.insert_after(args.start(), struct_name.str() + " &this_" + suffix);
1606 }
1607 });
1608 }
1609
1610 /* `*this` -> `this_` */
1611 scope.foreach_match("*T", [&](const std::vector<Token> &tokens) {
1612 fn_parser.replace(tokens[0], tokens[1], "this_");
1613 });
1614 /* `this->` -> `this_.` */
1615 scope.foreach_match("TD", [&](const std::vector<Token> &tokens) {
1616 fn_parser.replace(tokens[0], tokens[1], "this_.");
1617 });
1618 });
1619
1620 string line_directive = "#line " + std::to_string(fn_start.line_number()) + '\n';
1621 parser.erase(fn_start.line_start(), fn_body.end().line_end());
1622 parser.insert_after(struct_end.line_end() + 1, line_directive + fn_parser.result_get());
1623 });
1624
1625 string line_directive = "#line " + std::to_string(struct_end.line_number() + 1) + '\n';
1626 parser.insert_after(struct_end.line_end() + 1, line_directive);
1627 });
1628 });
1629
1630 parser.apply_mutations();
1631 }
1632
1633 /* Add padding member to empty structs.
1634 * Empty structs are useful for templating. */
1635 void empty_struct_mutation(Parser &parser, report_callback /*report_error*/)
1636 {
1637 using namespace std;
1638 using namespace shader::parser;
1639
1640 parser.foreach_scope(ScopeType::Global, [&](Scope scope) {
1641 scope.foreach_match("sw{};", [&](const std::vector<Token> &tokens) {
1642 parser.insert_after(tokens[2], "int _pad;");
1643 });
1644 });
1645 parser.apply_mutations();
1646 }
1647
1648 /* Transform `a.fn(b)` into `fn(a, b)`. */
1649 void method_call_mutation(Parser &parser, report_callback report_error)
1650 {
1651 using namespace std;
1652 using namespace shader::parser;
1653
1654 do {
1655 parser.foreach_scope(ScopeType::Function, [&](Scope scope) {
1656 scope.foreach_match(".w(", [&](const std::vector<Token> &tokens) {
1657 const Token dot = tokens[0];
1658 const Token func = tokens[1];
1659 const Token par_open = tokens[2];
1660 const Token end_of_this = dot.prev();
1661 Token start_of_this = end_of_this;
1662 while (true) {
1663 if (start_of_this == ')') {
1664 /* Function call. Take argument scope and function name. No recursion. */
1665 start_of_this = start_of_this.scope().start().prev();
1666 break;
1667 }
1668 if (start_of_this == ']') {
1669 /* Array subscript. Take scope and continue. */
1670 start_of_this = start_of_this.scope().start().prev();
1671 continue;
1672 }
1673 if (start_of_this == Word) {
1674 /* Member. */
1675 if (start_of_this.prev() == '.') {
1676 start_of_this = start_of_this.prev().prev();
1677 /* Continue until we find root member. */
1678 continue;
1679 }
1680 /* End of chain. */
1681 break;
1682 }
1683 report_error(start_of_this.line_number(),
1684 start_of_this.char_number(),
1685 start_of_this.line_str(),
1686 "method_call_mutation parsing error");
1687 break;
1688 }
1689 string this_str = parser.substr_range_inclusive(start_of_this, end_of_this);
1690 string func_str = func.str();
1691 const bool has_no_arg = par_open.next() == ')';
1692 /* `a.fn(b)` -> `fn(a, b)` */
1693 parser.replace_try(
1694 start_of_this, par_open, func_str + "(" + this_str + (has_no_arg ? "" : ", "));
1695 });
1696 });
1697 } while (parser.apply_mutations());
1698 }
1699
1700 void stage_function_mutation(Parser &parser, report_callback /*report_error*/)
1701 {
1702 using namespace std;
1703 using namespace shader::parser;
1704
1705 parser.foreach_function([&](bool is_static, Token fn_type, Token, Scope, bool, Scope fn_body) {
1706 Token attr_tok = (is_static) ? fn_type.prev().prev() : fn_type.prev();
1707 if (attr_tok.is_invalid() || attr_tok != ']' || attr_tok.prev() != ']') {
1708 return;
1709 }
1710 Scope attribute = attr_tok.prev().scope();
1711 if (attribute.type() != ScopeType::Subscript) {
1712 return;
1713 }
1714
1715 const string attr = attribute.str_exclusive();
1716 parser.erase(attribute.scope());
1717
1718 string condition = "defined(";
1719 if (attr == "gpu::vertex_function") {
1720 condition += "GPU_VERTEX_SHADER";
1721 }
1722 else if (attr == "gpu::fragment_function") {
1723 condition += "GPU_FRAGMENT_SHADER";
1724 }
1725 else if (attr == "gpu::compute_function") {
1726 condition += "GPU_COMPUTE_SHADER";
1727 }
1728 else {
1729 return;
1730 }
1731 condition += ")";
1732
1733 guarded_scope_mutation(parser, fn_body, condition);
1734 });
1735 parser.apply_mutations();
1736 }
1737
1738 void resource_guard_mutation(Parser &parser, report_callback /*report_error*/)
1739 {
1740 using namespace std;
1741 using namespace shader::parser;
1742
1743 parser.foreach_function([&](bool, Token fn_type, Token, Scope, bool, Scope fn_body) {
1744 fn_body.foreach_match("w(w,", [&](const std::vector<Token> &tokens) {
1745 string func_name = tokens[0].str();
1746 if (func_name != "specialization_constant_get" && func_name != "shared_variable_get" &&
1747 func_name != "push_constant_get" && func_name != "interface_get" &&
1748 func_name != "attribute_get" && func_name != "buffer_get" &&
1749 func_name != "sampler_get" && func_name != "image_get")
1750 {
1751 return;
1752 }
1753 string info_name = tokens[2].str();
1754 Scope scope = tokens[0].scope();
1755 /* We can be in expression scope. Take parent scope until we find a local scope. */
1756 while (scope.type() != ScopeType::Function && scope.type() != ScopeType::Local) {
1757 scope = scope.scope();
1758 }
1759
1760 string condition = "defined(CREATE_INFO_" + info_name + ")";
1761
1762 if (scope.type() == ScopeType::Function) {
1763 guarded_scope_mutation(parser, scope, condition, fn_type);
1764 }
1765 else {
1766 guarded_scope_mutation(parser, scope, condition);
1767 }
1768 });
1769 });
1770
1771 parser.apply_mutations();
1772 }
1773
1774 void guarded_scope_mutation(parser::Parser &parser,
1775 parser::Scope scope,
1776 const std::string &condition,
1777 parser::Token fn_type = parser::Token::invalid())
1778 {
1779 using namespace std;
1780 using namespace shader::parser;
1781
1782 string line_start = "#line " + std::to_string(scope.start().next().line_number()) + "\n";
1783 string line_end = "#line " + std::to_string(scope.end().line_number()) + "\n";
1784
1785 string guard_start = "#if " + condition + "\n";
1786 string guard_else;
1787 if (fn_type.is_valid() && fn_type.str() != "void") {
1788 string type = fn_type.str();
1789 bool is_trivial = false;
1790 if (type == "float" || type == "float2" || type == "float3" || type == "float4" ||
1791 /**/
1792 type == "int" || type == "int2" || type == "int3" || type == "int4" ||
1793 /**/
1794 type == "uint" || type == "uint2" || type == "uint3" || type == "uint4" ||
1795 /**/
1796 type == "float2x2" || type == "float2x3" || type == "float2x4" ||
1797 /**/
1798 type == "float3x2" || type == "float3x3" || type == "float3x4" ||
1799 /**/
1800 type == "float4x2" || type == "float4x3" || type == "float4x4")
1801 {
1802 is_trivial = true;
1803 }
1804 guard_else += "#else\n";
1805 guard_else += line_start;
1806 guard_else += " return " + type + (is_trivial ? "(0)" : "::zero()") + ";\n";
1807 }
1808 string guard_end = "#endif\n";
1809
1810 parser.insert_after(scope.start().line_end() + 1, guard_start + line_start);
1811 parser.insert_before(scope.end().line_start(), guard_else + guard_end + line_end);
1812 };
1813
1814 std::string guarded_scope_mutation(std::string content, int64_t line_start, std::string check)
1815 {
1816 int64_t line_end = line_start + line_count(content);
1817 std::string guarded_cope;
1818 guarded_cope += "#if " + check + "\n";
1819 guarded_cope += "#line " + std::to_string(line_start) + "\n";
1820 guarded_cope += content;
1821 guarded_cope += "#endif\n";
1822 guarded_cope += "#line " + std::to_string(line_end) + "\n";
1823 return guarded_cope;
1824 }
1825
1826 std::string enum_macro_injection(const std::string &str,
1827 bool is_shared_file,
1828 report_callback &report_error)
1829 {
1859 using namespace std;
1860 using namespace shader::parser;
1861
1862 Parser parser(str, report_error);
1863
1864 auto missing_underlying_type = [&](vector<Token> tokens) {
1865 report_error(tokens[0].line_number(),
1866 tokens[0].char_number(),
1867 tokens[0].line_str(),
1868 "enum declaration must explicitly use an underlying type");
1869 };
1870
1871 parser.foreach_match("Mw{", missing_underlying_type);
1872 parser.foreach_match("MSw{", missing_underlying_type);
1873
1874 auto process_enum =
1875 [&](Token enum_tok, Token class_tok, Token enum_name, Token enum_type, Scope enum_scope) {
1876 string type_str = enum_type.str();
1877
1878 if (is_shared_file) {
1879 if (type_str != "uint32_t" && type_str != "int32_t") {
1880 report_error(
1881 enum_type.line_number(),
1882 enum_type.char_number(),
1883 enum_type.line_str(),
1884 "enum declaration must use uint32_t or int32_t underlying type for interface "
1885 "compatibility");
1886 return;
1887 }
1888 }
1889
1890 size_t insert_at = enum_scope.end().line_end();
1891 parser.erase(enum_tok.str_index_start(), insert_at);
1892 parser.insert_line_number(insert_at + 1, enum_tok.line_number());
1893 parser.insert_after(insert_at + 1,
1894 "#define " + enum_name.str() + " " + enum_type.str() + "\n");
1895
1896 enum_scope.foreach_scope(ScopeType::Assignment, [&](Scope scope) {
1897 string name = scope.start().prev().str();
1898 string value = scope.str();
1899 if (class_tok.is_valid()) {
1900 name = enum_name.str() + "::" + name;
1901 }
1902 string decl = "constant static constexpr " + type_str + " " + name + " " + value +
1903 ";\n";
1904 parser.insert_line_number(insert_at + 1, scope.start().line_number());
1905 parser.insert_after(insert_at + 1, decl);
1906 });
1907 parser.insert_line_number(insert_at + 1, enum_scope.end().line_number() + 1);
1908 };
1909
1910 parser.foreach_match("MSw:w{", [&](vector<Token> tokens) {
1911 process_enum(tokens[0], tokens[1], tokens[2], tokens[4], tokens[5].scope());
1912 });
1913 parser.foreach_match("Mw:w{", [&](vector<Token> tokens) {
1914 process_enum(tokens[0], Token::invalid(), tokens[1], tokens[3], tokens[4].scope());
1915 });
1916
1917 parser.apply_mutations();
1918
1919 parser.foreach_match("M", [&](vector<Token> tokens) {
1920 report_error(tokens[0].line_number(),
1921 tokens[0].char_number(),
1922 tokens[0].line_str(),
1923 "invalid enum declaration");
1924 });
1925 return parser.result_get();
1926 }
1927
1928 std::string strip_whitespace(const std::string &str) const
1929 {
1930 return str.substr(0, str.find_last_not_of(" \n") + 1);
1931 }
1932
1937 void default_argument_mutation(Parser &parser, report_callback /*report_error*/)
1938 {
1939 using namespace std;
1940 using namespace shader::parser;
1941
1942 parser.foreach_function(
1943 [&](bool, Token fn_type, Token fn_name, Scope fn_args, bool, Scope fn_body) {
1944 if (!fn_args.contains_token('=')) {
1945 return;
1946 }
1947
1948 const bool has_non_void_return_type = fn_type.str() != "void";
1949
1950 string args_decl;
1951 string args_names;
1952
1953 vector<string> fn_overloads;
1954
1955 fn_args.foreach_scope(ScopeType::FunctionArg, [&](Scope arg) {
1956 Token equal = arg.find_token('=');
1957 const char *comma = (args_decl.empty() ? "" : ", ");
1958 if (equal.is_invalid()) {
1959 args_decl += comma + arg.str();
1960 args_names += comma + arg.end().str();
1961 }
1962 else {
1963 string arg_name = equal.prev().str();
1964 string value = parser.substr_range_inclusive(equal.next(), arg.end());
1965 string decl = parser.substr_range_inclusive(arg.start(), equal.prev());
1966
1967 string fn_call = fn_name.str() + '(' + args_names + comma + value + ");";
1968 if (has_non_void_return_type) {
1969 fn_call = "return " + fn_call;
1970 }
1971 string overload;
1972 overload += fn_type.str() + " ";
1973 overload += fn_name.str() + '(' + args_decl + ")\n";
1974 overload += "{\n";
1975 overload += "#line " + std::to_string(fn_type.line_number()) + "\n";
1976 overload += " " + fn_call + "\n}\n";
1977 fn_overloads.emplace_back(overload);
1978
1979 args_decl += comma + strip_whitespace(decl);
1980 args_names += comma + arg_name;
1981 /* Erase the value assignment and keep the declaration. */
1982 parser.erase(equal.scope());
1983 }
1984 });
1985 size_t end_of_fn_char = fn_body.end().line_end() + 1;
1986 /* Have to reverse the declaration order. */
1987 for (auto it = fn_overloads.rbegin(); it != fn_overloads.rend(); ++it) {
1988 parser.insert_line_number(end_of_fn_char, fn_type.line_number());
1989 parser.insert_after(end_of_fn_char, *it);
1990 }
1991 parser.insert_line_number(end_of_fn_char, fn_body.end().line_number() + 1);
1992 });
1993
1994 parser.apply_mutations();
1995 }
1996
/* Used to make GLSL matrix constructor compatible with MSL in pyGPU shaders.
 * This syntax is not supported in blender's own shaders.
 *
 * Square matrix constructors are renamed in two passes:
 * `mat2(x)` > `mat2x2(x)` > `__mat2x2(x)`.
 * Non-square matrix constructors (e.g. `mat3x4(x)`) are returned unchanged. */
std::string matrix_constructor_mutation(const std::string &str)
{
  /* Early out: without this substring there is nothing to replace. */
  if (str.find("mat") == std::string::npos) {
    return str;
  }
  /* Example: `mat2(x)` > `mat2x2(x)` */
  std::regex regex_parenthesis(R"(\bmat([234])\()");
  std::string out = std::regex_replace(str, regex_parenthesis, "mat$1x$1(");
  /* Only process square matrices since this is the only types we overload the constructors. */
  /* Example: `mat2x2(x)` > `__mat2x2(x)` */
  std::regex regex(R"(\bmat(2x2|3x3|4x4)\()");
  return std::regex_replace(out, regex, "__mat$1(");
}
2012
2013 /* To be run before `argument_decorator_macro_injection()`. */
2014 void argument_reference_mutation(Parser &parser, report_callback /*report_error*/)
2015 {
2016 using namespace std;
2017 using namespace shader::parser;
2018
2019 auto add_mutation = [&](Token type, Token arg_name, Token last_tok) {
2020 if (type.prev() == Const) {
2021 parser.replace(type.prev(), last_tok, type.str() + " " + arg_name.str());
2022 }
2023 else {
2024 parser.replace(type, last_tok, "inout " + type.str() + " " + arg_name.str());
2025 }
2026 };
2027
2028 parser.foreach_scope(ScopeType::FunctionArgs, [&](const Scope scope) {
2029 scope.foreach_match(
2030 "w(&w)", [&](const vector<Token> toks) { add_mutation(toks[0], toks[3], toks[4]); });
2031 scope.foreach_match(
2032 "w&w", [&](const vector<Token> toks) { add_mutation(toks[0], toks[2], toks[2]); });
2033 scope.foreach_match(
2034 "w&T", [&](const vector<Token> toks) { add_mutation(toks[0], toks[2], toks[2]); });
2035 });
2036 parser.apply_mutations();
2037 }
2038
2039 /* To be run after `argument_reference_mutation()`. */
2040 std::string variable_reference_mutation(const std::string &str, report_callback report_error)
2041 {
2042 using namespace std;
2043 /* Processing regex and logic is expensive. Check if they are needed at all. */
2044 bool valid_match = false;
2045 string next_str = str;
2046 reference_search(next_str, [&](int parenthesis_depth, int /*bracket_depth*/, char &c) {
2047 /* Check if inside a function body. */
2048 if (parenthesis_depth == 0) {
2049 valid_match = true;
2050 /* Modify the & into @ to make sure we only match these references in the regex
2051 * below. @ being forbidden in the shader language, it is safe to use a temp
2052 * character. */
2053 c = '@';
2054 }
2055 });
2056 if (!valid_match) {
2057 return str;
2058 }
2059 string out_str;
2060 /* Example: `const float &var = value;` */
2061 regex regex_ref(R"(\ ?(?:const)?\s*\w+\s+\@(\w+) =\s*([^;]+);)");
2062
2063 smatch match;
2064 while (regex_search(next_str, match, regex_ref)) {
2065 const string definition = match[0].str();
2066 const string name = match[1].str();
2067 const string value = match[2].str();
2068 const string prefix = match.prefix().str();
2069 const string suffix = match.suffix().str();
2070
2071 out_str += prefix;
2072
2073 /* Assert definition doesn't contain any side effect. */
2074 if (value.find("++") != string::npos || value.find("--") != string::npos) {
2075 report_error(line_number(match),
2076 char_number(match),
2077 line_str(match),
2078 "Reference definitions cannot have side effects.");
2079 return str;
2080 }
2081 if (value.find("(") != string::npos) {
2082 if (value.find("specialization_constant_get(") == string::npos &&
2083 value.find("push_constant_get(") == string::npos &&
2084 value.find("interface_get(") == string::npos &&
2085 value.find("attribute_get(") == string::npos &&
2086 value.find("buffer_get(") == string::npos &&
2087 value.find("sampler_get(") == string::npos && value.find("image_get(") == string::npos)
2088 {
2089 report_error(line_number(match),
2090 char_number(match),
2091 line_str(match),
2092 "Reference definitions cannot contain function calls.");
2093 return str;
2094 }
2095 }
2096 if (value.find("[") != string::npos) {
2097 const string index_var = get_content_between_balanced_pair(value, '[', ']');
2098
2099 if (index_var.find(' ') != string::npos) {
2100 report_error(line_number(match),
2101 char_number(match),
2102 line_str(match),
2103 "Array subscript inside reference declaration must be a single variable or "
2104 "a constant, not an expression.");
2105 return str;
2106 }
2107
2108 /* Add a space to avoid empty scope breaking the loop. */
2109 string scope_depth = " }";
2110 bool found_var = false;
2111 while (!found_var) {
2112 string scope = get_content_between_balanced_pair(out_str + scope_depth, '{', '}', true);
2113 scope_depth += '}';
2114
2115 if (scope.empty()) {
2116 break;
2117 }
2118 /* Remove nested scopes. Avoid variable shadowing to mess with the detection. */
2119 scope = regex_replace(scope, regex(R"(\{[^\}]*\})"), "{}");
2120 /* Search if index variable definition qualifies it as `const`. */
2121 regex regex_definition(R"((const)? \w+ )" + index_var + " =");
2122 smatch match_definition;
2123 if (regex_search(scope, match_definition, regex_definition)) {
2124 found_var = true;
2125 if (match_definition[1].matched == false) {
2126 report_error(line_number(match),
2127 char_number(match),
2128 line_str(match),
2129 "Array subscript variable must be declared as const qualified.");
2130 return str;
2131 }
2132 }
2133 }
2134 if (!found_var) {
2135 report_error(line_number(match),
2136 char_number(match),
2137 line_str(match),
2138 "Cannot locate array subscript variable declaration. "
2139 "If it is a global variable, assign it to a temporary const variable for "
2140 "indexing inside the reference.");
2141 return str;
2142 }
2143 }
2144
2145 /* Find scope this definition is active in. */
2146 const string scope = get_content_between_balanced_pair('{' + suffix, '{', '}');
2147 if (scope.empty()) {
2148 report_error(line_number(match),
2149 char_number(match),
2150 line_str(match),
2151 "Reference is defined inside a global or unterminated scope.");
2152 return str;
2153 }
2154 string original = definition + scope;
2155 string modified = original;
2156
2157 /* Replace definition by nothing. Keep number of lines. */
2158 string newlines(line_count(definition), '\n');
2159 replace_all(modified, definition, newlines);
2160 /* Replace every occurrence of the reference. Avoid matching other symbols like class members
2161 * and functions with the same name. */
2162 modified = regex_replace(
2163 modified, regex(R"(([^\.])\b)" + name + R"(\b([^(]))"), "$1" + value + "$2");
2164
2166 next_str = definition + suffix;
2167
2168 /* Replace whole modified scope in output string. */
2169 replace_all(next_str, original, modified);
2170 }
2171 out_str += next_str;
2172 return out_str;
2173 }
2174
/* Surround the name of every `out`, `inout`, `in` or `shared` qualified declaration with
 * the `_<qualifier>_sta` and `_<qualifier>_end` identifiers. */
std::string argument_decorator_macro_injection(const std::string &str)
{
  /* Example: `out float var[2]` > `out float _out_sta var _out_end[2]` */
  const std::regex qualified_declaration(R"((out|inout|in|shared)\s+(\w+)\s+(\w+))");
  /* `$1` is the qualifier, `$2` the type and `$3` the variable name. */
  const char *decorated_declaration = "$1 $2 _$1_sta $3 _$1_end";
  return std::regex_replace(str, qualified_declaration, decorated_declaration);
}
2181
/* Replace GLSL array constructors found on the right hand side of an assignment by the
 * `ARRAY_T` and `ARRAY_V` macros. The content of the square brackets is discarded. */
std::string array_constructor_macro_injection(const std::string &str)
{
  /* Example: `= float[2](0.0, 0.0)` > `= ARRAY_T(float) ARRAY_V(0.0, 0.0)` */
  std::regex regex(R"(=\s*(\w+)\s*\[[^\]]*\]\s*\()");
  return std::regex_replace(str, regex, "= ARRAY_T($1) ARRAY_V(");
}
2188
2189 /* Assume formatted source with our code style. Cannot be applied to python shaders. */
2190 void global_scope_constant_linting(Parser &parser, report_callback report_error)
2191 {
2192 using namespace std;
2193 using namespace shader::parser;
2194
2195 /* Example: `const uint global_var = 1u;`. */
2196 parser.foreach_match("cww=", [&](const vector<Token> &tokens) {
2197 if (tokens[0].scope().type() == ScopeType::Global) {
2198 report_error(
2199 ERROR_TOK(tokens[2]),
2200 "Global scope constant expression found. These get allocated per-thread in MSL. "
2201 "Use Macro's or uniforms instead.");
2202 }
2203 });
2204 }
2205
2206 void quote_linting(const std::string &str, report_callback report_error)
2207 {
2208 using namespace std;
2209 using namespace shader::parser;
2210
2211 Parser parser(str, report_error);
2212 /* This only catches some invalid usage. For the rest, the CI will catch them. */
2213 parser.foreach_token(TokenType::String, [&](const Token token) {
2214 report_error(ERROR_TOK(token),
2215 "Unprocessed string literal. "
2216 "Strings are forbidden in GLSL.");
2217 });
2218 }
2219
2220 void small_type_linting(Parser &parser, report_callback report_error)
2221 {
2222 using namespace std;
2223 using namespace shader::parser;
2224
2225 parser.foreach_scope(ScopeType::Struct, [&](const Scope scope) {
2226 scope.foreach_match("ww;", [&](const vector<Token> tokens) {
2227 string type = tokens[0].str();
2228 if (type.find("char") != string::npos || type.find("short") != string::npos ||
2229 type.find("half") != string::npos)
2230 {
2231 report_error(ERROR_TOK(tokens[0]), "Small types are forbidden in shader interfaces.");
2232 }
2233 });
2234 });
2235 }
2236
2237 std::string threadgroup_variables_suffix()
2238 {
2239 if (shared_vars_.empty()) {
2240 return "";
2241 }
2242
2243 std::stringstream suffix;
2257 suffix << "\n";
2258 /* Arguments of the wrapper class constructor. */
2259 suffix << "#undef MSL_SHARED_VARS_ARGS\n";
2260 /* References assignment inside wrapper class constructor. */
2261 suffix << "#undef MSL_SHARED_VARS_ASSIGN\n";
2262 /* Declaration of threadgroup variables in entry point function. */
2263 suffix << "#undef MSL_SHARED_VARS_DECLARE\n";
2264 /* Arguments for wrapper class constructor call. */
2265 suffix << "#undef MSL_SHARED_VARS_PASS\n";
2266
2305 std::stringstream args, assign, declare, pass;
2306
2307 bool first = true;
2308 for (SharedVar &var : shared_vars_) {
2309 char sep = first ? ' ' : ',';
2310
2311 args << sep << "threadgroup " << var.type << "(&_" << var.name << ")" << var.array;
2312 assign << (first ? ':' : ',') << var.name << "(_" << var.name << ")";
2313 declare << "threadgroup " << var.type << ' ' << var.name << var.array << ";";
2314 pass << sep << var.name;
2315 first = false;
2316 }
2317
2318 suffix << "#define MSL_SHARED_VARS_ARGS " << args.str() << "\n";
2319 suffix << "#define MSL_SHARED_VARS_ASSIGN " << assign.str() << "\n";
2320 suffix << "#define MSL_SHARED_VARS_DECLARE " << declare.str() << "\n";
2321 suffix << "#define MSL_SHARED_VARS_PASS (" << pass.str() << ")\n";
2322 suffix << "\n";
2323
2324 return suffix.str();
2325 }
2326
/* Return a `#line 1 "<filename>"` directive (newline terminated) where `<filename>` is
 * `filepath` stripped from everything up to and including its last `/`. */
std::string line_directive_prefix(const std::string &filepath)
{
  /* Greedy match of the directory part. Only the last path component is captured. */
  const std::regex directory_and_name(R"((?:.*)\/(.*))");
  const std::string filename = std::regex_replace(filepath, directory_and_name, "$1");

  /* NOTE: This is not supported by GLSL. All line directives are muted at runtime and the
   * sources are scanned after error reporting for the locating the muted line. */
  std::string directive = "#line 1 \"";
  directive += filename;
  directive += "\"\n";
  return directive;
}
2337
2338 /* Made public for unit testing purpose. */
2339 public:
/* Return the content enclosed by the first balanced pair of delimiters found in `input`,
 * delimiters excluded. Nested pairs are kept inside the returned content.
 * If `backwards` is true, the search starts from the end of `input` and returns the content of
 * the last balanced pair instead.
 * Return an empty string if no balanced pair is found. */
static std::string get_content_between_balanced_pair(const std::string &input,
                                                     char start_delimiter,
                                                     char end_delimiter,
                                                     const bool backwards = false)
{
  /* When scanning backwards, the closing delimiter is encountered first. */
  const char opener = backwards ? end_delimiter : start_delimiter;
  const char closer = backwards ? start_delimiter : end_delimiter;
  const size_t length = input.length();

  int depth = 0;
  /* Index of the delimiter that opened the outermost pair. */
  size_t opened_at = std::string::npos;

  for (size_t step = 0; step < length; ++step) {
    const size_t idx = backwards ? (length - 1) - step : step;
    const char current = input[idx];

    if (current == opener) {
      if (depth == 0) {
        opened_at = idx;
      }
      ++depth;
      continue;
    }
    if (current != closer) {
      continue;
    }
    --depth;
    if (depth != 0 || opened_at == std::string::npos) {
      continue;
    }
    /* Order the two delimiter positions regardless of the scanning direction. */
    const size_t first = backwards ? idx : opened_at;
    const size_t last = backwards ? opened_at : idx;
    return input.substr(first + 1, last - first - 1);
  }
  return "";
}
2374
/* Return a copy of `input` where every `from` character located between `start_delimiter`
 * and `end_delimiter` is replaced by `to`, even inside nested delimiters pair.
 * The delimiters themselves are never replaced. */
static std::string replace_char_between_balanced_pair(const std::string &input,
                                                      const char start_delimiter,
                                                      const char end_delimiter,
                                                      const char from,
                                                      const char to)
{
  std::string result = input;
  /* Number of pairs currently opened. */
  int nesting = 0;
  for (size_t i = 0; i < result.size(); ++i) {
    const char current = result[i];
    if (current == start_delimiter) {
      ++nesting;
      continue;
    }
    if (current == end_delimiter) {
      --nesting;
      continue;
    }
    if (nesting > 0 && current == from) {
      result[i] = to;
    }
  }
  return result;
}
2399
/* Split `str` at every `delimiter` and return the pieces in order, delimiters excluded.
 * Empty pieces are kept, except for the one following a trailing delimiter. */
static std::vector<std::string> split_string(const std::string &str, const char delimiter)
{
  std::vector<std::string> pieces;
  std::stringstream stream(str);
  for (std::string piece; std::getline(stream, piece, delimiter);) {
    pieces.push_back(piece);
  }
  return pieces;
}
2412
2413 /* Similar to split_string but only split if the delimiter is not between any pair_start and
2414 * pair_end. */
2415 static std::vector<std::string> split_string_not_between_balanced_pair(const std::string &str,
2416 const char delimiter,
2417 const char pair_start,
2418 const char pair_end)
2419 {
2420 const char safe_char = '@';
2421 const std::string safe_str = replace_char_between_balanced_pair(
2422 str, pair_start, pair_end, delimiter, safe_char);
2423 std::vector<std::string> split = split_string(safe_str, delimiter);
2424 for (std::string &str : split) {
2425 replace_all(str, safe_char, delimiter);
2426 }
2427 return split;
2428 }
2429
/* Replace in place every occurrence of `from` by `to` inside `str`.
 * The replaced text is not scanned again, so `to` is allowed to contain `from`.
 * Does nothing if `from` is empty. */
static void replace_all(std::string &str, const std::string &from, const std::string &to)
{
  if (from.empty()) {
    return;
  }
  for (size_t found_at = str.find(from, 0); found_at != std::string::npos;
       /* Resume the search right after the inserted text. */
       found_at = str.find(from, found_at + to.length()))
  {
    str.replace(found_at, from.length(), to);
  }
}
2441
/* Replace in place every `from` character by `to` inside `str`. */
static void replace_all(std::string &str, const char from, const char to)
{
  for (size_t i = 0; i < str.size(); ++i) {
    if (str[i] != from) {
      continue;
    }
    str[i] = to;
  }
}
2450
/* Return the number of occurrences of the character `c` inside `str`. */
static int64_t char_count(const std::string &str, char c)
{
  int64_t occurrences = 0;
  for (const char current : str) {
    if (current == c) {
      ++occurrences;
    }
  }
  return occurrences;
}
2455
2456 static int64_t line_count(const std::string &str)
2457 {
2458 return char_count(str, '\n');
2459 }
2460
2461 /* Match any reference definition (e.g. `int &a = b`).
2462 * Call the callback function for each `&` character that matches a reference definition.
2463 * Expects the input `str` to be formatted with balanced parenthesis and curly brackets. */
2464 static void reference_search(std::string &str, std::function<void(int, int, char &)> callback)
2465 {
2467 str, '&', [&](size_t pos, int parenthesis_depth, int bracket_depth, char &c) {
2468 if (pos > 0 && pos <= str.length() - 2) {
2469 /* This is made safe by the previous check. */
2470 char prev_char = str[pos - 1];
2471 char next_char = str[pos + 1];
2472 /* Validate it is not an operator (`&`, `&&`, `&=`). */
2473 if (prev_char == ' ' || prev_char == '(') {
2474 if (next_char != ' ' && next_char != '\n' && next_char != '&' && next_char != '=') {
2475 callback(parenthesis_depth, bracket_depth, c);
2476 }
2477 }
2478 }
2479 });
2480 }
2481
2482 /* Match any default argument definition (e.g. `void func(int a = 0)`).
2483 * Call the callback function for each `=` character inside a function argument list.
2484 * Expects the input `str` to be formatted with balanced parenthesis and curly brackets. */
2485 static void default_argument_search(std::string &str,
2486 std::function<void(int, int, char &)> callback)
2487 {
2489 str, '=', [&](size_t pos, int parenthesis_depth, int bracket_depth, char &c) {
2490 if (pos > 0 && pos <= str.length() - 2) {
2491 /* This is made safe by the previous check. */
2492 char prev_char = str[pos - 1];
2493 char next_char = str[pos + 1];
2494 /* Validate it is not an operator (`==`, `<=`, `>=`). Expects formatted input. */
2495 if (prev_char == ' ' && next_char == ' ') {
2496 if (parenthesis_depth == 1 && bracket_depth == 0) {
2497 callback(parenthesis_depth, bracket_depth, c);
2498 }
2499 }
2500 }
2501 });
2502 }
2503
/* Scan through a string matching for every occurrence of a character.
 * Calls the callback with the context in which the match occurs: the position of the match, the
 * number of parenthesis and curly brackets opened before it, and a mutable reference to the
 * matched character.
 * NOTE: A matched character never modifies the depths, even if it is a parenthesis or a curly
 * bracket. */
static void scopes_scan_for_char(std::string &str,
                                 char search_char,
                                 std::function<void(size_t, int, int, char &)> callback)
{
  int parenthesis_depth = 0;
  int bracket_depth = 0;
  for (size_t pos = 0; pos < str.size(); ++pos) {
    char &current = str[pos];
    if (current == search_char) {
      callback(pos, parenthesis_depth, bracket_depth, current);
      continue;
    }
    switch (current) {
      case '(':
        ++parenthesis_depth;
        break;
      case ')':
        --parenthesis_depth;
        break;
      case '{':
        ++bracket_depth;
        break;
      case '}':
        --bracket_depth;
        break;
      default:
        break;
    }
  }
}
2532
/* Return the line number this token is found at. Take into account the #line directives.
 * `pos` is the offset of the token inside `file_str`. Lines are counted starting from 1.
 * NOTE: The text following the nearest `#line ` is parsed using `std::stoll`, which throws if
 * it does not start with a number. */
static size_t line_number(const std::string &file_str, size_t pos)
{
  const std::string directive = "#line ";
  /* Only what precedes the token influences its line number. */
  std::string preceding = file_str.substr(0, pos);

  size_t first_line = 1;
  const size_t directive_pos = preceding.rfind(directive);
  if (directive_pos != std::string::npos) {
    /* Count from the nearest directive. Its own newline is part of the count below, hence
     * the decrement. */
    preceding = preceding.substr(directive_pos + directive.size());
    first_line = std::stoll(preceding) - 1;
  }
  return first_line + std::count(preceding.begin(), preceding.end(), '\n');
}
2546 static size_t line_number(const std::smatch &smatch)
2547 {
2548 std::string whole_file = smatch.prefix().str() + smatch[0].str() + smatch.suffix().str();
2549 return line_number(whole_file, smatch.prefix().str().size());
2550 }
2551
/* Return the offset to the start of the line.
 * `pos` is the offset of the token inside `file_str`.
 * - If a newline precedes the token, the offset is relative to this newline character.
 * - On the first line, the result is `pos - 1`, clamped to 0 in order to avoid the unsigned
 *   subtraction wrapping around when `pos` is 0. */
static size_t char_number(const std::string &file_str, size_t pos)
{
  const std::string sub_str = file_str.substr(0, pos);
  const size_t last_newline = sub_str.find_last_of('\n');
  if (last_newline != std::string::npos) {
    return sub_str.size() - last_newline;
  }
  /* First line of the file. */
  return sub_str.empty() ? 0 : (sub_str.size() - 1);
}
2561 static size_t char_number(const std::smatch &smatch)
2562 {
2563 std::string whole_file = smatch.prefix().str() + smatch[0].str() + smatch.suffix().str();
2564 return char_number(whole_file, smatch.prefix().str().size());
2565 }
2566
/* Return the line the token is at.
 * `pos` is the offset of the token inside `file_str`.
 * NOTE: Except for the first line, the returned string starts with the newline character that
 * precedes the line. The newline ending the line is not included. */
static std::string line_str(const std::string &file_str, size_t pos)
{
  const size_t newline_before = file_str.rfind('\n', pos);
  const size_t newline_after = file_str.find('\n', pos);
  /* The first line starts at the beginning of the string. */
  const size_t first = (newline_before == std::string::npos) ? 0 : newline_before;
  /* If this is the last line, the length wraps around and `substr` clamps it to the end. */
  return file_str.substr(first, newline_after - first);
}
2577 static std::string line_str(const std::smatch &smatch)
2578 {
2579 std::string whole_file = smatch.prefix().str() + smatch[0].str() + smatch.suffix().str();
2580 return line_str(whole_file, smatch.prefix().str().size());
2581 }
2582};
2583
2584} // namespace blender::gpu::shader
void BLI_kdtree_nd_ balance(KDTree *tree) ATTR_NONNULL(1)
static void split(const char *text, const char *seps, char ***str, int *count)
void init()
long long int int64_t
unsigned long long int uint64_t
static void replace_all(std::string &str, const char from, const char to)
std::string process(SourceLanguage language, std::string str, const std::string &filename, bool do_parse_function, bool do_small_type_linting, report_callback report_error, metadata::Source &r_metadata)
static std::string get_content_between_balanced_pair(const std::string &input, char start_delimiter, char end_delimiter, const bool backwards=false)
std::string process(const std::string &str)
static void reference_search(std::string &str, std::function< void(int, int, char &)> callback)
static size_t line_number(const std::smatch &smatch)
static void default_argument_search(std::string &str, std::function< void(int, int, char &)> callback)
static size_t char_number(const std::string &file_str, size_t pos)
static std::string line_str(const std::string &file_str, size_t pos)
static std::string replace_char_between_balanced_pair(const std::string &input, const char start_delimiter, const char end_delimiter, const char from, const char to)
static void scopes_scan_for_char(std::string &str, char search_char, std::function< void(size_t, int, int, char &)> callback)
static std::string line_str(const std::smatch &smatch)
static SourceLanguage language_from_filename(const std::string &filename)
static int64_t char_count(const std::string &str, char c)
static int64_t line_count(const std::string &str)
static void replace_all(std::string &str, const std::string &from, const std::string &to)
static std::vector< std::string > split_string_not_between_balanced_pair(const std::string &str, const char delimiter, const char pair_start, const char pair_end)
static size_t char_number(const std::smatch &smatch)
static std::vector< std::string > split_string(const std::string &str, const char delimiter)
static size_t line_number(const std::string &file_str, size_t pos)
dot(value.rgb, luminance_coefficients)") DEFINE_VALUE("REDUCE(lhs
#define str(s)
#define ERROR_TOK(token)
uint pos
#define input
#define out
VecBase< bool, D > equal(VecOp< T, D >, VecOp< T, D >) RET
if(state< num_states)
format
static void error(const char *str)
static constexpr uint64_t hash(const char *name)
SymEdge< T > * prev(const SymEdge< T > *se)
#define hash
Definition noise_c.cc:154
static string line_directive(const SourceReplaceState &state, const string &path, const size_t line_number)
Definition path.cpp:797
const char * name
std::string serialize(const std::string &function_name) const
std::vector< FunctionFormat > functions
std::vector< std::string > dependencies
std::vector< PrintfFormat > printf_formats
i
Definition text_draw.cc:230