MLIR 22.0.0git
ByteCode.cpp
Go to the documentation of this file.
1//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements MLIR to byte-code generation and the interpreter.
10//
11//===----------------------------------------------------------------------===//
12
13#include "ByteCode.h"
17#include "mlir/IR/BuiltinOps.h"
19#include "llvm/ADT/IntervalMap.h"
20#include "llvm/ADT/PostOrderIterator.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/Format.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/InterleavedRange.h"
27#include <numeric>
28#include <optional>
29
30#define DEBUG_TYPE "pdl-bytecode"
31
32using namespace mlir;
33using namespace mlir::detail;
34
35//===----------------------------------------------------------------------===//
36// PDLByteCodePattern
37//===----------------------------------------------------------------------===//
38
39PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
40 PDLPatternConfigSet *configSet,
41 ByteCodeAddr rewriterAddr) {
42 PatternBenefit benefit = matchOp.getBenefit();
43 MLIRContext *ctx = matchOp.getContext();
44
45 // Collect the set of generated operations.
46 SmallVector<StringRef, 8> generatedOps;
47 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
48 generatedOps =
49 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
50
51 // Check to see if this is pattern matches a specific operation type.
52 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
53 return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
54 generatedOps);
55 return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
56 benefit, ctx, generatedOps);
57}
58
59//===----------------------------------------------------------------------===//
60// PDLByteCodeMutableState
61//===----------------------------------------------------------------------===//
62
63/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
64/// to the position of the pattern within the range returned by
65/// `PDLByteCode::getPatterns`.
66void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
67 PatternBenefit benefit) {
68 currentPatternBenefits[patternIndex] = benefit;
69}
70
71/// Cleanup any allocated state after a full match/rewrite has been completed.
72/// This method should be called irregardless of whether the match+rewrite was a
73/// success or not.
75 allocatedTypeRangeMemory.clear();
76 allocatedValueRangeMemory.clear();
77}
78
79//===----------------------------------------------------------------------===//
80// Bytecode OpCodes
81//===----------------------------------------------------------------------===//
82
namespace {
/// The opcodes understood by the bytecode generator in this file. The
/// underlying type is `ByteCodeField`, so an opcode occupies one field of the
/// bytecode stream. No enumerator has an explicit value: the numeric encoding
/// is determined purely by declaration order, so entries must not be
/// reordered independently of the code that consumes them.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create a type range from a list of constant types.
  CreateConstantTypeRange,
  /// Create an operation.
  CreateOperation,
  /// Create a type range from a list of dynamic types.
  CreateDynamicTypeRange,
  /// Create a value range.
  CreateDynamicValueRange,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): presumably the 0-3 variants encode the operand index in the
  /// opcode itself while `N` takes it as an explicit field; confirm against
  /// the interpreter.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): presumably mirrors the `GetOperand0..N` scheme; confirm
  /// against the interpreter.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
173
/// A marker used to indicate if an operation should infer types. The marker is
/// the largest value representable by `ByteCodeField`.
static constexpr ByteCodeField kInferTypesMarker =
    std::numeric_limits<ByteCodeField>::max();
177
178//===----------------------------------------------------------------------===//
179// ByteCode Generation
180//===----------------------------------------------------------------------===//
181
182//===----------------------------------------------------------------------===//
183// Generator
184//===----------------------------------------------------------------------===//
185
186namespace {
187struct ByteCodeLiveRange;
188struct ByteCodeWriter;
189
/// Check if the given class `T` can be converted to an opaque pointer.
///
/// The alias is only well-formed when `T` provides a `getAsOpaquePointer()`
/// member, which makes it usable as the detection operation passed to
/// `llvm::is_detected`. `Args` is not referenced by the alias itself.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
193
194/// This class represents the main generator for the pattern bytecode.
195class Generator {
196public:
197 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
198 SmallVectorImpl<ByteCodeField> &matcherByteCode,
199 SmallVectorImpl<ByteCodeField> &rewriterByteCode,
201 ByteCodeField &maxValueMemoryIndex,
202 ByteCodeField &maxOpRangeMemoryIndex,
203 ByteCodeField &maxTypeRangeMemoryIndex,
204 ByteCodeField &maxValueRangeMemoryIndex,
205 ByteCodeField &maxLoopLevel,
206 llvm::StringMap<PDLConstraintFunction> &constraintFns,
207 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
209 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
210 rewriterByteCode(rewriterByteCode), patterns(patterns),
211 maxValueMemoryIndex(maxValueMemoryIndex),
212 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
213 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
214 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
215 maxLoopLevel(maxLoopLevel), configMap(configMap) {
216 for (const auto &it : llvm::enumerate(constraintFns))
217 constraintToMemIndex.try_emplace(it.value().first(), it.index());
218 for (const auto &it : llvm::enumerate(rewriteFns))
219 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
220 }
221
222 /// Generate the bytecode for the given PDL interpreter module.
223 void generate(ModuleOp module);
224
225 /// Return the memory index to use for the given value.
226 ByteCodeField &getMemIndex(Value value) {
227 assert(valueToMemIndex.count(value) &&
228 "expected memory index to be assigned");
229 return valueToMemIndex[value];
230 }
231
232 /// Return the range memory index used to store the given range value.
233 ByteCodeField &getRangeStorageIndex(Value value) {
234 assert(valueToRangeIndex.count(value) &&
235 "expected range index to be assigned");
236 return valueToRangeIndex[value];
237 }
238
239 /// Return an index to use when referring to the given data that is uniqued in
240 /// the MLIR context.
241 template <typename T>
242 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
243 getMemIndex(T val) {
244 const void *opaqueVal = val.getAsOpaquePointer();
245
246 // Get or insert a reference to this value.
247 auto it = uniquedDataToMemIndex.try_emplace(
248 opaqueVal, maxValueMemoryIndex + uniquedData.size());
249 if (it.second)
250 uniquedData.push_back(opaqueVal);
251 return it.first->second;
252 }
253
254private:
255 /// Allocate memory indices for the results of operations within the matcher
256 /// and rewriters.
257 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
258 ModuleOp rewriterModule);
259
260 /// Generate the bytecode for the given operation.
261 void generate(Region *region, ByteCodeWriter &writer);
262 void generate(Operation *op, ByteCodeWriter &writer);
263 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
298 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
299 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
300 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
301
302 /// Mapping from value to its corresponding memory index.
303 DenseMap<Value, ByteCodeField> valueToMemIndex;
304
305 /// Mapping from a range value to its corresponding range storage index.
306 DenseMap<Value, ByteCodeField> valueToRangeIndex;
307
308 /// Mapping from the name of an externally registered rewrite to its index in
309 /// the bytecode registry.
310 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
311
312 /// Mapping from the name of an externally registered constraint to its index
313 /// in the bytecode registry.
314 llvm::StringMap<ByteCodeField> constraintToMemIndex;
315
316 /// Mapping from rewriter function name to the bytecode address of the
317 /// rewriter function in byte.
318 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
319
320 /// Mapping from a uniqued storage object to its memory index within
321 /// `uniquedData`.
322 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
323
324 /// The current level of the foreach loop.
325 ByteCodeField curLoopLevel = 0;
326
327 /// The current MLIR context.
328 MLIRContext *ctx;
329
330 /// Mapping from block to its address.
332
333 /// Data of the ByteCode class to be populated.
334 std::vector<const void *> &uniquedData;
335 SmallVectorImpl<ByteCodeField> &matcherByteCode;
336 SmallVectorImpl<ByteCodeField> &rewriterByteCode;
337 SmallVectorImpl<PDLByteCodePattern> &patterns;
338 ByteCodeField &maxValueMemoryIndex;
339 ByteCodeField &maxOpRangeMemoryIndex;
340 ByteCodeField &maxTypeRangeMemoryIndex;
341 ByteCodeField &maxValueRangeMemoryIndex;
342 ByteCodeField &maxLoopLevel;
343
344 /// A map of pattern configurations.
346};
347
348/// This class provides utilities for writing a bytecode stream.
349struct ByteCodeWriter {
350 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
351 : bytecode(bytecode), generator(generator) {}
352
353 /// Append a field to the bytecode.
354 void append(ByteCodeField field) { bytecode.push_back(field); }
355 void append(OpCode opCode) { bytecode.push_back(opCode); }
356
357 /// Append an address to the bytecode.
358 void append(ByteCodeAddr field) {
359 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
360 "unexpected ByteCode address size");
361
362 ByteCodeField fieldParts[2];
363 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
364 bytecode.append({fieldParts[0], fieldParts[1]});
365 }
366
367 /// Append a single successor to the bytecode, the exact address will need to
368 /// be resolved later.
369 void append(Block *successor) {
370 // Add back a reference to the successor so that the address can be resolved
371 // later.
372 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
373 append(ByteCodeAddr(0));
374 }
375
376 /// Append a successor range to the bytecode, the exact address will need to
377 /// be resolved later.
378 void append(SuccessorRange successors) {
379 for (Block *successor : successors)
380 append(successor);
381 }
382
383 /// Append a range of values that will be read as generic PDLValues.
384 void appendPDLValueList(OperandRange values) {
385 bytecode.push_back(values.size());
386 for (Value value : values)
387 appendPDLValue(value);
388 }
389
390 /// Append a value as a PDLValue.
391 void appendPDLValue(Value value) {
392 appendPDLValueKind(value);
393 append(value);
394 }
395
396 /// Append the PDLValue::Kind of the given value.
397 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
398
399 /// Append the PDLValue::Kind of the given type.
400 void appendPDLValueKind(Type type) {
401 PDLValue::Kind kind =
403 .Case<pdl::AttributeType>(
404 [](Type) { return PDLValue::Kind::Attribute; })
405 .Case<pdl::OperationType>(
406 [](Type) { return PDLValue::Kind::Operation; })
407 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
408 if (isa<pdl::TypeType>(rangeTy.getElementType()))
409 return PDLValue::Kind::TypeRange;
410 return PDLValue::Kind::ValueRange;
411 })
412 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
413 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
414 bytecode.push_back(static_cast<ByteCodeField>(kind));
415 }
416
417 /// Append a value that will be stored in a memory slot and not inline within
418 /// the bytecode.
419 template <typename T>
420 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
421 std::is_pointer<T>::value>
422 append(T value) {
423 bytecode.push_back(generator.getMemIndex(value));
424 }
425
426 /// Append a range of values.
427 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
428 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
429 append(T range) {
430 bytecode.push_back(llvm::size(range));
431 for (auto it : range)
432 append(it);
433 }
434
435 /// Append a variadic number of fields to the bytecode.
436 template <typename FieldTy, typename Field2Ty, typename... FieldTys>
437 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
438 append(field);
439 append(field2, fields...);
440 }
441
442 /// Appends a value as a pointer, stored inline within the bytecode.
443 template <typename T>
444 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
445 appendInline(T value) {
446 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
447 const void *pointer = value.getAsOpaquePointer();
448 ByteCodeField fieldParts[numParts];
449 std::memcpy(fieldParts, &pointer, sizeof(const void *));
450 bytecode.append(fieldParts, fieldParts + numParts);
451 }
452
453 /// Successor references in the bytecode that have yet to be resolved.
454 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
455
456 /// The underlying bytecode buffer.
457 SmallVectorImpl<ByteCodeField> &bytecode;
458
459 /// The main generator producing PDL.
460 Generator &generator;
461};
462
463/// This class represents a live range of PDL Interpreter values, containing
464/// information about when values are live within a match/rewrite.
465struct ByteCodeLiveRange {
466 using Set = llvm::IntervalMap<uint64_t, char, 16>;
467 using Allocator = Set::Allocator;
468
469 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
470
471 /// Union this live range with the one provided.
472 void unionWith(const ByteCodeLiveRange &rhs) {
473 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
474 ++it)
475 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
476 }
477
478 /// Returns true if this range overlaps with the one provided.
479 bool overlaps(const ByteCodeLiveRange &rhs) const {
480 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
481 .valid();
482 }
483
484 /// A map representing the ranges of the match/rewrite that a value is live in
485 /// the interpreter.
486 ///
487 /// We use std::unique_ptr here, because IntervalMap does not provide a
488 /// correct copy or move constructor. We can eliminate the pointer once
489 /// https://reviews.llvm.org/D113240 lands.
490 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
491
492 /// The operation range storage index for this range.
493 std::optional<unsigned> opRangeIndex;
494
495 /// The type range storage index for this range.
496 std::optional<unsigned> typeRangeIndex;
497
498 /// The value range storage index for this range.
499 std::optional<unsigned> valueRangeIndex;
500};
501} // namespace
502
503void Generator::generate(ModuleOp module) {
504 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
505 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
506 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
507 pdl_interp::PDLInterpDialect::getRewriterModuleName());
508 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
509
510 // Allocate memory indices for the results of operations within the matcher
511 // and rewriters.
512 allocateMemoryIndices(matcherFunc, rewriterModule);
513
514 // Generate code for the rewriter functions.
515 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
516 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
517 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
518 for (Operation &op : rewriterFunc.getOps())
519 generate(&op, rewriterByteCodeWriter);
520 }
521 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
522 "unexpected branches in rewriter function");
523
524 // Generate code for the matcher function.
525 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
526 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
527
528 // Resolve successor references in the matcher.
529 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
530 ByteCodeAddr addr = blockToAddr[it.first];
531 for (unsigned offsetToFix : it.second)
532 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
533 }
534}
535
536void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
537 ModuleOp rewriterModule) {
538 // Rewriters use simplistic allocation scheme that simply assigns an index to
539 // each result.
540 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
541 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
542 auto processRewriterValue = [&](Value val) {
543 valueToMemIndex.try_emplace(val, index++);
544 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
545 Type elementTy = rangeType.getElementType();
546 if (isa<pdl::TypeType>(elementTy))
547 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
548 else if (isa<pdl::ValueType>(elementTy))
549 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
550 }
551 };
552
553 for (BlockArgument arg : rewriterFunc.getArguments())
554 processRewriterValue(arg);
555 rewriterFunc.getBody().walk([&](Operation *op) {
556 for (Value result : op->getResults())
557 processRewriterValue(result);
558 });
559 if (index > maxValueMemoryIndex)
560 maxValueMemoryIndex = index;
561 if (typeRangeIndex > maxTypeRangeMemoryIndex)
562 maxTypeRangeMemoryIndex = typeRangeIndex;
563 if (valueRangeIndex > maxValueRangeMemoryIndex)
564 maxValueRangeMemoryIndex = valueRangeIndex;
565 }
566
567 // The matcher function uses a more sophisticated numbering that tries to
568 // minimize the number of memory indices assigned. This is done by determining
569 // a live range of the values within the matcher, then the allocation is just
570 // finding the minimal number of overlapping live ranges. This is essentially
571 // a simplified form of register allocation where we don't necessarily have a
572 // limited number of registers, but we still want to minimize the number used.
573 DenseMap<Operation *, unsigned> opToFirstIndex;
575
576 // A custom walk that marks the first and the last index of each operation.
577 // The entry marks the beginning of the liveness range for this operation,
578 // followed by nested operations, followed by the end of the liveness range.
579 unsigned index = 0;
580 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
581 opToFirstIndex.try_emplace(op, index++);
582 for (Region &region : op->getRegions())
583 for (Block &block : region.getBlocks())
584 for (Operation &nested : block)
585 walk(&nested);
586 opToLastIndex.try_emplace(op, index++);
587 };
588 walk(matcherFunc);
589
590 // Liveness info for each of the defs within the matcher.
591 ByteCodeLiveRange::Allocator allocator;
593
594 // Assign the root operation being matched to slot 0.
595 BlockArgument rootOpArg = matcherFunc.getArgument(0);
596 valueToMemIndex[rootOpArg] = 0;
597
598 // Walk each of the blocks, computing the def interval that the value is used.
599 Liveness matcherLiveness(matcherFunc);
600 matcherFunc->walk([&](Block *block) {
601 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
602 assert(info && "expected liveness info for block");
603 auto processValue = [&](Value value, Operation *firstUseOrDef) {
604 // We don't need to process the root op argument, this value is always
605 // assigned to the first memory slot.
606 if (value == rootOpArg)
607 return;
608
609 // Set indices for the range of this block that the value is used.
610 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
611 defRangeIt->second.liveness->insert(
612 opToFirstIndex[firstUseOrDef],
613 opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
614 /*dummyValue*/ 0);
615
616 // Check to see if this value is a range type.
617 if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
618 Type eleType = rangeTy.getElementType();
619 if (isa<pdl::OperationType>(eleType))
620 defRangeIt->second.opRangeIndex = 0;
621 else if (isa<pdl::TypeType>(eleType))
622 defRangeIt->second.typeRangeIndex = 0;
623 else if (isa<pdl::ValueType>(eleType))
624 defRangeIt->second.valueRangeIndex = 0;
625 }
626 };
627
628 // Process the live-ins of this block.
629 for (Value liveIn : info->in()) {
630 // Only process the value if it has been defined in the current region.
631 // Other values that span across pdl_interp.foreach will be added higher
632 // up. This ensures that the we keep them alive for the entire duration
633 // of the loop.
634 if (liveIn.getParentRegion() == block->getParent())
635 processValue(liveIn, &block->front());
636 }
637
638 // Process the block arguments for the entry block (those are not live-in).
639 if (block->isEntryBlock()) {
640 for (Value argument : block->getArguments())
641 processValue(argument, &block->front());
642 }
643
644 // Process any new defs within this block.
645 for (Operation &op : *block)
646 for (Value result : op.getResults())
647 processValue(result, &op);
648 });
649
650 // Greedily allocate memory slots using the computed def live ranges.
651 std::vector<ByteCodeLiveRange> allocatedIndices;
652
653 // The number of memory indices currently allocated (and its next value).
654 // Recall that the root gets allocated memory index 0.
655 ByteCodeField numIndices = 1;
656
657 // The number of memory ranges of various types (and their next values).
658 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
659
660 for (auto &defIt : valueDefRanges) {
661 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
662 ByteCodeLiveRange &defRange = defIt.second;
663
664 // Try to allocate to an existing index.
665 for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
666 ByteCodeLiveRange &existingRange = existingIndexIt.value();
667 if (!defRange.overlaps(existingRange)) {
668 existingRange.unionWith(defRange);
669 memIndex = existingIndexIt.index() + 1;
670
671 if (defRange.opRangeIndex) {
672 if (!existingRange.opRangeIndex)
673 existingRange.opRangeIndex = numOpRanges++;
674 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
675 } else if (defRange.typeRangeIndex) {
676 if (!existingRange.typeRangeIndex)
677 existingRange.typeRangeIndex = numTypeRanges++;
678 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
679 } else if (defRange.valueRangeIndex) {
680 if (!existingRange.valueRangeIndex)
681 existingRange.valueRangeIndex = numValueRanges++;
682 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
683 }
684 break;
685 }
686 }
687
688 // If no existing index could be used, add a new one.
689 if (memIndex == 0) {
690 allocatedIndices.emplace_back(allocator);
691 ByteCodeLiveRange &newRange = allocatedIndices.back();
692 newRange.unionWith(defRange);
693
694 // Allocate an index for op/type/value ranges.
695 if (defRange.opRangeIndex) {
696 newRange.opRangeIndex = numOpRanges;
697 valueToRangeIndex[defIt.first] = numOpRanges++;
698 } else if (defRange.typeRangeIndex) {
699 newRange.typeRangeIndex = numTypeRanges;
700 valueToRangeIndex[defIt.first] = numTypeRanges++;
701 } else if (defRange.valueRangeIndex) {
702 newRange.valueRangeIndex = numValueRanges;
703 valueToRangeIndex[defIt.first] = numValueRanges++;
704 }
705
706 memIndex = allocatedIndices.size();
707 ++numIndices;
708 }
709 }
710
711 // Print the index usage and ensure that we did not run out of index space.
712 LDBG() << "Allocated " << allocatedIndices.size() << " indices "
713 << "(down from initial " << valueDefRanges.size() << ").";
714 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
715 "Ran out of memory for allocated indices");
716
717 // Update the max number of indices.
718 if (numIndices > maxValueMemoryIndex)
719 maxValueMemoryIndex = numIndices;
720 if (numOpRanges > maxOpRangeMemoryIndex)
721 maxOpRangeMemoryIndex = numOpRanges;
722 if (numTypeRanges > maxTypeRangeMemoryIndex)
723 maxTypeRangeMemoryIndex = numTypeRanges;
724 if (numValueRanges > maxValueRangeMemoryIndex)
725 maxValueRangeMemoryIndex = numValueRanges;
726}
727
728void Generator::generate(Region *region, ByteCodeWriter &writer) {
729 llvm::ReversePostOrderTraversal<Region *> rpot(region);
730 for (Block *block : rpot) {
731 // Keep track of where this block begins within the matcher function.
732 blockToAddr.try_emplace(block, matcherByteCode.size());
733 for (Operation &op : *block)
734 generate(&op, writer);
735 }
736}
737
738void Generator::generate(Operation *op, ByteCodeWriter &writer) {
739 LDBG() << "Generating bytecode for operation: " << op->getName();
740 LLVM_DEBUG({
741 // The following list must contain all the operations that do not
742 // produce any bytecode.
743 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
744 writer.appendInline(op->getLoc());
745 });
747 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
748 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
749 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
750 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
751 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
752 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
753 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
754 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
755 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
756 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
757 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
758 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
759 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
760 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
761 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
762 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
763 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
764 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
765 pdl_interp::SwitchResultCountOp>(
766 [&](auto interpOp) { this->generate(interpOp, writer); })
767 .DefaultUnreachable("unknown `pdl_interp` operation");
768}
769
770void Generator::generate(pdl_interp::ApplyConstraintOp op,
771 ByteCodeWriter &writer) {
772 // Constraints that should return a value have to be registered as rewrites.
773 // If a constraint and a rewrite of similar name are registered the
774 // constraint takes precedence
775 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
776 writer.appendPDLValueList(op.getArgs());
777 writer.append(ByteCodeField(op.getIsNegated()));
778 ResultRange results = op.getResults();
779 writer.append(ByteCodeField(results.size()));
780 for (Value result : results) {
781 // We record the expected kind of the result, so that we can provide extra
782 // verification of the native rewrite function and handle the failure case
783 // of constraints accordingly.
784 writer.appendPDLValueKind(result);
785
786 // Range results also need to append the range storage index.
787 if (isa<pdl::RangeType>(result.getType()))
788 writer.append(getRangeStorageIndex(result));
789 writer.append(result);
790 }
791 writer.append(op.getSuccessors());
792}
793void Generator::generate(pdl_interp::ApplyRewriteOp op,
794 ByteCodeWriter &writer) {
795 assert(externalRewriterToMemIndex.count(op.getName()) &&
796 "expected index for rewrite function");
797 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
798 writer.appendPDLValueList(op.getArgs());
799
800 ResultRange results = op.getResults();
801 writer.append(ByteCodeField(results.size()));
802 for (Value result : results) {
803 // We record the expected kind of the result, so that we
804 // can provide extra verification of the native rewrite function.
805 writer.appendPDLValueKind(result);
806
807 // Range results also need to append the range storage index.
808 if (isa<pdl::RangeType>(result.getType()))
809 writer.append(getRangeStorageIndex(result));
810 writer.append(result);
811 }
812}
813void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
814 Value lhs = op.getLhs();
815 if (isa<pdl::RangeType>(lhs.getType())) {
816 writer.append(OpCode::AreRangesEqual);
817 writer.appendPDLValueKind(lhs);
818 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
819 return;
820 }
821
822 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
823}
824void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
825 writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
826}
827void Generator::generate(pdl_interp::CheckAttributeOp op,
828 ByteCodeWriter &writer) {
829 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
830 op.getSuccessors());
831}
832void Generator::generate(pdl_interp::CheckOperandCountOp op,
833 ByteCodeWriter &writer) {
834 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
835 static_cast<ByteCodeField>(op.getCompareAtLeast()),
836 op.getSuccessors());
837}
838void Generator::generate(pdl_interp::CheckOperationNameOp op,
839 ByteCodeWriter &writer) {
840 writer.append(OpCode::CheckOperationName, op.getInputOp(),
841 OperationName(op.getName(), ctx), op.getSuccessors());
842}
843void Generator::generate(pdl_interp::CheckResultCountOp op,
844 ByteCodeWriter &writer) {
845 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
846 static_cast<ByteCodeField>(op.getCompareAtLeast()),
847 op.getSuccessors());
848}
849void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
850 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
851 op.getSuccessors());
852}
853void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
854 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
855 op.getSuccessors());
856}
857void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
858 assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
859 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
860}
861void Generator::generate(pdl_interp::CreateAttributeOp op,
862 ByteCodeWriter &writer) {
863 // Simply repoint the memory index of the result to the constant.
864 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
865}
866void Generator::generate(pdl_interp::CreateOperationOp op,
867 ByteCodeWriter &writer) {
868 writer.append(OpCode::CreateOperation, op.getResultOp(),
869 OperationName(op.getName(), ctx));
870 writer.appendPDLValueList(op.getInputOperands());
871
872 // Add the attributes.
873 OperandRange attributes = op.getInputAttributes();
874 writer.append(static_cast<ByteCodeField>(attributes.size()));
875 for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
876 writer.append(std::get<0>(it), std::get<1>(it));
877
878 // Add the result types. If the operation has inferred results, we use a
879 // marker "size" value. Otherwise, we add the list of explicit result types.
880 if (op.getInferredResultTypes())
881 writer.append(kInferTypesMarker);
882 else
883 writer.appendPDLValueList(op.getInputResultTypes());
884}
885void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
886 // Append the correct opcode for the range type.
887 TypeSwitch<Type>(op.getType().getElementType())
888 .Case(
889 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
890 .Case([&](pdl::ValueType) {
891 writer.append(OpCode::CreateDynamicValueRange);
892 });
893
894 writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
895 writer.appendPDLValueList(op->getOperands());
896}
897void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
898 // Simply repoint the memory index of the result to the constant.
899 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
900}
901void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
902 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
903 getRangeStorageIndex(op.getResult()), op.getValue());
904}
905void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
906 writer.append(OpCode::EraseOp, op.getInputOp());
907}
908void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
909 OpCode opCode =
910 TypeSwitch<Type, OpCode>(op.getResult().getType())
911 .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
912 .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
913 .Case([](pdl::TypeType) { return OpCode::ExtractType; })
914 .DefaultUnreachable("unsupported element type");
915 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
916}
917void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
918 writer.append(OpCode::Finalize);
919}
920void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
921 BlockArgument arg = op.getLoopVariable();
922 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
923 writer.appendPDLValueKind(arg.getType());
924 writer.append(curLoopLevel, op.getSuccessor());
925 ++curLoopLevel;
926 if (curLoopLevel > maxLoopLevel)
927 maxLoopLevel = curLoopLevel;
928 generate(&op.getRegion(), writer);
929 --curLoopLevel;
930}
931void Generator::generate(pdl_interp::GetAttributeOp op,
932 ByteCodeWriter &writer) {
933 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
934 op.getNameAttr());
935}
936void Generator::generate(pdl_interp::GetAttributeTypeOp op,
937 ByteCodeWriter &writer) {
938 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
939}
940void Generator::generate(pdl_interp::GetDefiningOpOp op,
941 ByteCodeWriter &writer) {
942 writer.append(OpCode::GetDefiningOp, op.getInputOp());
943 writer.appendPDLValue(op.getValue());
944}
945void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
946 uint32_t index = op.getIndex();
947 if (index < 4)
948 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
949 else
950 writer.append(OpCode::GetOperandN, index);
951 writer.append(op.getInputOp(), op.getValue());
952}
953void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
954 Value result = op.getValue();
955 std::optional<uint32_t> index = op.getIndex();
956 writer.append(OpCode::GetOperands,
957 index.value_or(std::numeric_limits<uint32_t>::max()),
958 op.getInputOp());
959 if (isa<pdl::RangeType>(result.getType()))
960 writer.append(getRangeStorageIndex(result));
961 else
962 writer.append(std::numeric_limits<ByteCodeField>::max());
963 writer.append(result);
964}
965void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
966 uint32_t index = op.getIndex();
967 if (index < 4)
968 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
969 else
970 writer.append(OpCode::GetResultN, index);
971 writer.append(op.getInputOp(), op.getValue());
972}
973void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
974 Value result = op.getValue();
975 std::optional<uint32_t> index = op.getIndex();
976 writer.append(OpCode::GetResults,
977 index.value_or(std::numeric_limits<uint32_t>::max()),
978 op.getInputOp());
979 if (isa<pdl::RangeType>(result.getType()))
980 writer.append(getRangeStorageIndex(result));
981 else
982 writer.append(std::numeric_limits<ByteCodeField>::max());
983 writer.append(result);
984}
985void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
986 Value operations = op.getOperations();
987 ByteCodeField rangeIndex = getRangeStorageIndex(operations);
988 writer.append(OpCode::GetUsers, operations, rangeIndex);
989 writer.appendPDLValue(op.getValue());
990}
991void Generator::generate(pdl_interp::GetValueTypeOp op,
992 ByteCodeWriter &writer) {
993 if (isa<pdl::RangeType>(op.getType())) {
994 Value result = op.getResult();
995 writer.append(OpCode::GetValueRangeTypes, result,
996 getRangeStorageIndex(result), op.getValue());
997 } else {
998 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
999 }
1000}
1001void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1002 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1003}
1004void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1005 ByteCodeField patternIndex = patterns.size();
1006 patterns.emplace_back(PDLByteCodePattern::create(
1007 op, configMap.lookup(op),
1008 rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1009 writer.append(OpCode::RecordMatch, patternIndex,
1010 SuccessorRange(op.getOperation()), op.getMatchedOps());
1011 writer.appendPDLValueList(op.getInputs());
1012}
1013void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1014 writer.append(OpCode::ReplaceOp, op.getInputOp());
1015 writer.appendPDLValueList(op.getReplValues());
1016}
1017void Generator::generate(pdl_interp::SwitchAttributeOp op,
1018 ByteCodeWriter &writer) {
1019 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1020 op.getCaseValuesAttr(), op.getSuccessors());
1021}
1022void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1023 ByteCodeWriter &writer) {
1024 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1025 op.getCaseValuesAttr(), op.getSuccessors());
1026}
1027void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1028 ByteCodeWriter &writer) {
1029 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1030 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1031 });
1032 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1033 op.getSuccessors());
1034}
1035void Generator::generate(pdl_interp::SwitchResultCountOp op,
1036 ByteCodeWriter &writer) {
1037 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1038 op.getCaseValuesAttr(), op.getSuccessors());
1039}
1040void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1041 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1042 op.getSuccessors());
1043}
1044void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1045 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1046 op.getSuccessors());
1047}
1048
1049//===----------------------------------------------------------------------===//
1050// PDLByteCode
1051//===----------------------------------------------------------------------===//
1052
1053PDLByteCode::PDLByteCode(
1054 ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1056 llvm::StringMap<PDLConstraintFunction> constraintFns,
1057 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1058 : configs(std::move(configs)) {
1059 Generator generator(module.getContext(), uniquedData, matcherByteCode,
1060 rewriterByteCode, patterns, maxValueMemoryIndex,
1061 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1062 maxLoopLevel, constraintFns, rewriteFns, configMap);
1063 generator.generate(module);
1064
1065 // Initialize the external functions.
1066 for (auto &it : constraintFns)
1067 constraintFunctions.push_back(std::move(it.second));
1068 for (auto &it : rewriteFns)
1069 rewriteFunctions.push_back(std::move(it.second));
1070}
1071
1072/// Initialize the given state such that it can be used to execute the current
1073/// bytecode.
1075 state.memory.resize(maxValueMemoryIndex, nullptr);
1076 state.opRangeMemory.resize(maxOpRangeCount);
1077 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1078 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1079 state.loopIndex.resize(maxLoopLevel, 0);
1080 state.currentPatternBenefits.reserve(patterns.size());
1081 for (const PDLByteCodePattern &pattern : patterns)
1082 state.currentPatternBenefits.push_back(pattern.getBenefit());
1083}
1084
1085//===----------------------------------------------------------------------===//
1086// ByteCode Execution
1087//===----------------------------------------------------------------------===//
1088
1089namespace {
1090/// This class is an instantiation of the PDLResultList that provides access to
1091/// the returned results. This API is not on `PDLResultList` to avoid
1092/// overexposing access to information specific solely to the ByteCode.
1093class ByteCodeRewriteResultList : public PDLResultList {
1094public:
1095 ByteCodeRewriteResultList(unsigned maxNumResults)
1096 : PDLResultList(maxNumResults) {}
1097
1098 /// Return the list of PDL results.
1099 MutableArrayRef<PDLValue> getResults() { return results; }
1100
1101 /// Return the type ranges allocated by this list.
1102 MutableArrayRef<std::vector<Type>> getAllocatedTypeRanges() {
1103 return allocatedTypeRanges;
1104 }
1105
1106 /// Return the value ranges allocated by this list.
1107 MutableArrayRef<std::vector<Value>> getAllocatedValueRanges() {
1108 return allocatedValueRanges;
1109 }
1110};
1111
1112/// This class provides support for executing a bytecode stream.
1113class ByteCodeExecutor {
1114public:
1115 ByteCodeExecutor(const ByteCodeField *curCodeIt,
1116 MutableArrayRef<const void *> memory,
1117 MutableArrayRef<std::vector<Operation *>> opRangeMemory,
1118 MutableArrayRef<TypeRange> typeRangeMemory,
1119 std::vector<std::vector<Type>> &allocatedTypeRangeMemory,
1120 MutableArrayRef<ValueRange> valueRangeMemory,
1121 std::vector<std::vector<Value>> &allocatedValueRangeMemory,
1122 MutableArrayRef<unsigned> loopIndex,
1123 ArrayRef<const void *> uniquedMemory,
1124 ArrayRef<ByteCodeField> code,
1125 ArrayRef<PatternBenefit> currentPatternBenefits,
1126 ArrayRef<PDLByteCodePattern> patterns,
1127 ArrayRef<PDLConstraintFunction> constraintFunctions,
1128 ArrayRef<PDLRewriteFunction> rewriteFunctions)
1129 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1130 typeRangeMemory(typeRangeMemory),
1131 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1132 valueRangeMemory(valueRangeMemory),
1133 allocatedValueRangeMemory(allocatedValueRangeMemory),
1134 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1135 currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1136 constraintFunctions(constraintFunctions),
1137 rewriteFunctions(rewriteFunctions) {}
1138
1139 /// Start executing the code at the current bytecode index. `matches` is an
1140 /// optional field provided when this function is executed in a matching
1141 /// context.
1142 LogicalResult
1143 execute(PatternRewriter &rewriter,
1144 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1145 std::optional<Location> mainRewriteLoc = {});
1146
1147private:
1148 /// Internal implementation of executing each of the bytecode commands.
1149 void executeApplyConstraint(PatternRewriter &rewriter);
1150 LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1151 void executeAreEqual();
1152 void executeAreRangesEqual();
1153 void executeBranch();
1154 void executeCheckOperandCount();
1155 void executeCheckOperationName();
1156 void executeCheckResultCount();
1157 void executeCheckTypes();
1158 void executeContinue();
1159 void executeCreateConstantTypeRange();
1160 void executeCreateOperation(PatternRewriter &rewriter,
1161 Location mainRewriteLoc);
1162 template <typename T>
1163 void executeDynamicCreateRange(StringRef type);
1164 void executeEraseOp(PatternRewriter &rewriter);
1165 template <typename T, typename Range, PDLValue::Kind kind>
1166 void executeExtract();
1167 void executeFinalize();
1168 void executeForEach();
1169 void executeGetAttribute();
1170 void executeGetAttributeType();
1171 void executeGetDefiningOp();
1172 void executeGetOperand(unsigned index);
1173 void executeGetOperands();
1174 void executeGetResult(unsigned index);
1175 void executeGetResults();
1176 void executeGetUsers();
1177 void executeGetValueType();
1178 void executeGetValueRangeTypes();
1179 void executeIsNotNull();
1180 void executeRecordMatch(PatternRewriter &rewriter,
1181 SmallVectorImpl<PDLByteCode::MatchResult> &matches);
1182 void executeReplaceOp(PatternRewriter &rewriter);
1183 void executeSwitchAttribute();
1184 void executeSwitchOperandCount();
1185 void executeSwitchOperationName();
1186 void executeSwitchResultCount();
1187 void executeSwitchType();
1188 void executeSwitchTypes();
1189 void processNativeFunResults(ByteCodeRewriteResultList &results,
1190 unsigned numResults,
1191 LogicalResult &rewriteResult);
1192
1193 /// Pushes a code iterator to the stack.
1194 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1195
1196 /// Pops a code iterator from the stack, returning true on success.
1197 void popCodeIt() {
1198 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1199 curCodeIt = resumeCodeIt.pop_back_val();
1200 }
1201
1202 /// Return the bytecode iterator at the start of the current op code.
1203 const ByteCodeField *getPrevCodeIt() const {
1204 LLVM_DEBUG({
1205 // Account for the op code and the Location stored inline.
1206 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1207 });
1208
1209 // Account for the op code only.
1210 return curCodeIt - 1;
1211 }
1212
1213 /// Read a value from the bytecode buffer, optionally skipping a certain
1214 /// number of prefix values. These methods always update the buffer to point
1215 /// to the next field after the read data.
1216 template <typename T = ByteCodeField>
1217 T read(size_t skipN = 0) {
1218 curCodeIt += skipN;
1219 return readImpl<T>();
1220 }
1221 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1222
1223 /// Read a list of values from the bytecode buffer.
1224 template <typename ValueT, typename T>
1225 void readList(SmallVectorImpl<T> &list) {
1226 list.clear();
1227 for (unsigned i = 0, e = read(); i != e; ++i)
1228 list.push_back(read<ValueT>());
1229 }
1230
1231 /// Read a list of values from the bytecode buffer. The values may be encoded
1232 /// either as a single element or a range of elements.
1233 void readList(SmallVectorImpl<Type> &list) {
1234 for (unsigned i = 0, e = read(); i != e; ++i) {
1235 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1236 list.push_back(read<Type>());
1237 } else {
1238 TypeRange *values = read<TypeRange *>();
1239 list.append(values->begin(), values->end());
1240 }
1241 }
1242 }
1243 void readList(SmallVectorImpl<Value> &list) {
1244 for (unsigned i = 0, e = read(); i != e; ++i) {
1245 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1246 list.push_back(read<Value>());
1247 } else {
1248 ValueRange *values = read<ValueRange *>();
1249 list.append(values->begin(), values->end());
1250 }
1251 }
1252 }
1253
1254 /// Read a value stored inline as a pointer.
1255 template <typename T>
1256 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1257 readInline() {
1258 const void *pointer;
1259 std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1260 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1261 return T::getFromOpaquePointer(pointer);
1262 }
1263
1264 void skip(size_t skipN) { curCodeIt += skipN; }
1265
1266 /// Jump to a specific successor based on a predicate value.
1267 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1268 /// Jump to a specific successor based on a destination index.
1269 void selectJump(size_t destIndex) {
1270 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1271 }
1272
1273 /// Handle a switch operation with the provided value and cases.
1274 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1275 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1276 LDBG() << "Switch operation:\n * Value: " << value
1277 << "\n * Cases: " << llvm::interleaved(cases);
1278
1279 // Check to see if the attribute value is within the case list. Jump to
1280 // the correct successor index based on the result.
1281 for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1282 if (cmp(*it, value))
1283 return selectJump(size_t((it - cases.begin()) + 1));
1284 selectJump(size_t(0));
1285 }
1286
1287 /// Store a pointer to memory.
1288 void storeToMemory(unsigned index, const void *value) {
1289 memory[index] = value;
1290 }
1291
1292 /// Store a value to memory as an opaque pointer.
1293 template <typename T>
1294 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1295 storeToMemory(unsigned index, T value) {
1296 memory[index] = value.getAsOpaquePointer();
1297 }
1298
1299 /// Internal implementation of reading various data types from the bytecode
1300 /// stream.
1301 template <typename T>
1302 const void *readFromMemory() {
1303 size_t index = *curCodeIt++;
1304
1305 // If this type is an SSA value, it can only be stored in non-const memory.
1306 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1307 Value>::value ||
1308 index < memory.size())
1309 return memory[index];
1310
1311 // Otherwise, if this index is not inbounds it is uniqued.
1312 return uniquedMemory[index - memory.size()];
1313 }
1314 template <typename T>
1315 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1316 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1317 }
1318 template <typename T>
1319 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1320 T>
1321 readImpl() {
1322 return T(T::getFromOpaquePointer(readFromMemory<T>()));
1323 }
1324 template <typename T>
1325 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1326 switch (read<PDLValue::Kind>()) {
1327 case PDLValue::Kind::Attribute:
1328 return read<Attribute>();
1329 case PDLValue::Kind::Operation:
1330 return read<Operation *>();
1331 case PDLValue::Kind::Type:
1332 return read<Type>();
1333 case PDLValue::Kind::Value:
1334 return read<Value>();
1335 case PDLValue::Kind::TypeRange:
1336 return read<TypeRange *>();
1337 case PDLValue::Kind::ValueRange:
1338 return read<ValueRange *>();
1339 }
1340 llvm_unreachable("unhandled PDLValue::Kind");
1341 }
1342 template <typename T>
1343 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1344 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1345 "unexpected ByteCode address size");
1346 ByteCodeAddr result;
1347 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1348 curCodeIt += 2;
1349 return result;
1350 }
1351 template <typename T>
1352 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1353 return *curCodeIt++;
1354 }
1355 template <typename T>
1356 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1357 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1358 }
1359
1360 /// Assign the given range to the given memory index. This allocates a new
1361 /// range object if necessary.
1362 template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1363 void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1364 unsigned rangeIndex) {
1365 // Utility functor used to type-erase the assignment.
1366 auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1367 // If the input range is empty, we don't need to allocate anything.
1368 if (range.empty()) {
1369 rangeMemory[rangeIndex] = {};
1370 } else {
1371 // Assign this to the range slot and use the range as the value for the
1372 // memory index.
1373 allocatedRangeMemory.emplace_back(range.begin(), range.end());
1374 rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1375 }
1376 memory[memIndex] = &rangeMemory[rangeIndex];
1377 };
1378
1379 // Dispatch based on the concrete range type.
1380 if constexpr (std::is_same_v<T, Type>) {
1381 return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1382 } else if constexpr (std::is_same_v<T, Value>) {
1383 return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1384 } else {
1385 llvm_unreachable("unhandled range type");
1386 }
1387 }
1388
1389 /// The underlying bytecode buffer.
1390 const ByteCodeField *curCodeIt;
1391
1392 /// The stack of bytecode positions at which to resume operation.
1393 SmallVector<const ByteCodeField *> resumeCodeIt;
1394
1395 /// The current execution memory.
1396 MutableArrayRef<const void *> memory;
1397 MutableArrayRef<std::vector<Operation *>> opRangeMemory;
1398 MutableArrayRef<TypeRange> typeRangeMemory;
1399 std::vector<std::vector<Type>> &allocatedTypeRangeMemory;
1400 MutableArrayRef<ValueRange> valueRangeMemory;
1401 std::vector<std::vector<Value>> &allocatedValueRangeMemory;
1402
1403 /// The current loop indices.
1404 MutableArrayRef<unsigned> loopIndex;
1405
1406 /// References to ByteCode data necessary for execution.
1407 ArrayRef<const void *> uniquedMemory;
1408 ArrayRef<ByteCodeField> code;
1409 ArrayRef<PatternBenefit> currentPatternBenefits;
1410 ArrayRef<PDLByteCodePattern> patterns;
1411 ArrayRef<PDLConstraintFunction> constraintFunctions;
1412 ArrayRef<PDLRewriteFunction> rewriteFunctions;
1413};
1414} // namespace
1415
1416void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1417 LDBG() << "Executing ApplyConstraint:";
1418 ByteCodeField fun_idx = read();
1419 SmallVector<PDLValue, 16> args;
1420 readList<PDLValue>(args);
1421
1422 LDBG() << " * Arguments: " << llvm::interleaved(args);
1423
1424 ByteCodeField isNegated = read();
1425 LDBG() << " * isNegated: " << isNegated;
1426
1427 ByteCodeField numResults = read();
1428 const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1429 ByteCodeRewriteResultList results(numResults);
1430 LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1431 [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1432 if (succeeded(rewriteResult)) {
1433 LDBG() << " * Constraint succeeded, results: "
1434 << llvm::interleaved(constraintResults);
1435 } else {
1436 LDBG() << " * Constraint failed";
1437 }
1438 assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1439 "native PDL rewrite function succeeded but returned "
1440 "unexpected number of results");
1441 processNativeFunResults(results, numResults, rewriteResult);
1442
1443 // Depending on the constraint jump to the proper destination.
1444 selectJump(isNegated != succeeded(rewriteResult));
1445}
1446
1447LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1448 LDBG() << "Executing ApplyRewrite:";
1449 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1450 SmallVector<PDLValue, 16> args;
1451 readList<PDLValue>(args);
1452
1453 LDBG() << " * Arguments: " << llvm::interleaved(args);
1454
1455 // Execute the rewrite function.
1456 ByteCodeField numResults = read();
1457 ByteCodeRewriteResultList results(numResults);
1458 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1459
1460 assert(results.getResults().size() == numResults &&
1461 "native PDL rewrite function returned unexpected number of results");
1462
1463 processNativeFunResults(results, numResults, rewriteResult);
1464
1465 if (failed(rewriteResult)) {
1466 LDBG() << " - Failed";
1467 return failure();
1468 }
1469 return success();
1470}
1471
1472void ByteCodeExecutor::processNativeFunResults(
1473 ByteCodeRewriteResultList &results, unsigned numResults,
1474 LogicalResult &rewriteResult) {
1475 if (failed(rewriteResult)) {
1476 // Skip the according number of values on the buffer on failure and exit
1477 // early as there are no results to process.
1478 for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1479 const PDLValue::Kind resultKind = read<PDLValue::Kind>();
1480 if (resultKind == PDLValue::Kind::TypeRange ||
1481 resultKind == PDLValue::Kind::ValueRange) {
1482 skip(2);
1483 } else {
1484 skip(1);
1485 }
1486 }
1487 return;
1488 }
1489
1490 // Store the results in the bytecode memory
1491 for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1492 PDLValue::Kind resultKind = read<PDLValue::Kind>();
1493 (void)resultKind;
1494 PDLValue result = results.getResults()[resultIdx];
1495 LDBG() << " * Result: " << result;
1496 assert(result.getKind() == resultKind &&
1497 "native PDL rewrite function returned an unexpected type of "
1498 "result");
1499 // If the result is a range, we need to copy it over to the bytecodes
1500 // range memory.
1501 if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1502 unsigned rangeIndex = read();
1503 typeRangeMemory[rangeIndex] = *typeRange;
1504 memory[read()] = &typeRangeMemory[rangeIndex];
1505 } else if (std::optional<ValueRange> valueRange =
1506 result.dyn_cast<ValueRange>()) {
1507 unsigned rangeIndex = read();
1508 valueRangeMemory[rangeIndex] = *valueRange;
1509 memory[read()] = &valueRangeMemory[rangeIndex];
1510 } else {
1511 memory[read()] = result.getAsOpaquePointer();
1512 }
1513 }
1514
1515 // Copy over any underlying storage allocated for result ranges.
1516 for (auto &it : results.getAllocatedTypeRanges())
1517 allocatedTypeRangeMemory.push_back(std::move(it));
1518 for (auto &it : results.getAllocatedValueRanges())
1519 allocatedValueRangeMemory.push_back(std::move(it));
1520}
1521
1522void ByteCodeExecutor::executeAreEqual() {
1523 LDBG() << "Executing AreEqual:";
1524 const void *lhs = read<const void *>();
1525 const void *rhs = read<const void *>();
1526
1527 LDBG() << " * " << lhs << " == " << rhs;
1528 selectJump(lhs == rhs);
1529}
1530
1531void ByteCodeExecutor::executeAreRangesEqual() {
1532 LDBG() << "Executing AreRangesEqual:";
1533 PDLValue::Kind valueKind = read<PDLValue::Kind>();
1534 const void *lhs = read<const void *>();
1535 const void *rhs = read<const void *>();
1536
1537 switch (valueKind) {
1538 case PDLValue::Kind::TypeRange: {
1539 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1540 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1541 LDBG() << " * " << lhs << " == " << rhs;
1542 selectJump(*lhsRange == *rhsRange);
1543 break;
1544 }
1545 case PDLValue::Kind::ValueRange: {
1546 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1547 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1548 LDBG() << " * " << lhs << " == " << rhs;
1549 selectJump(*lhsRange == *rhsRange);
1550 break;
1551 }
1552 default:
1553 llvm_unreachable("unexpected `AreRangesEqual` value kind");
1554 }
1555}
1556
1557void ByteCodeExecutor::executeBranch() {
1558 LDBG() << "Executing Branch";
1559 curCodeIt = &code[read<ByteCodeAddr>()];
1560}
1561
1562void ByteCodeExecutor::executeCheckOperandCount() {
1563 LDBG() << "Executing CheckOperandCount:";
1564 Operation *op = read<Operation *>();
1565 uint32_t expectedCount = read<uint32_t>();
1566 bool compareAtLeast = read();
1567
1568 LDBG() << " * Found: " << op->getNumOperands()
1569 << "\n * Expected: " << expectedCount
1570 << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
1571 if (compareAtLeast)
1572 selectJump(op->getNumOperands() >= expectedCount);
1573 else
1574 selectJump(op->getNumOperands() == expectedCount);
1575}
1576
1577void ByteCodeExecutor::executeCheckOperationName() {
1578 LDBG() << "Executing CheckOperationName:";
1579 Operation *op = read<Operation *>();
1580 OperationName expectedName = read<OperationName>();
1581
1582 LDBG() << " * Found: \"" << op->getName() << "\"\n * Expected: \""
1583 << expectedName << "\"";
1584 selectJump(op->getName() == expectedName);
1585}
1586
1587void ByteCodeExecutor::executeCheckResultCount() {
1588 LDBG() << "Executing CheckResultCount:";
1589 Operation *op = read<Operation *>();
1590 uint32_t expectedCount = read<uint32_t>();
1591 bool compareAtLeast = read();
1592
1593 LDBG() << " * Found: " << op->getNumResults()
1594 << "\n * Expected: " << expectedCount
1595 << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
1596 if (compareAtLeast)
1597 selectJump(op->getNumResults() >= expectedCount);
1598 else
1599 selectJump(op->getNumResults() == expectedCount);
1600}
1601
1602void ByteCodeExecutor::executeCheckTypes() {
1603 LDBG() << "Executing AreEqual:";
1604 TypeRange *lhs = read<TypeRange *>();
1605 Attribute rhs = read<Attribute>();
1606 LDBG() << " * " << lhs << " == " << rhs;
1607
1608 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1609}
1610
1611void ByteCodeExecutor::executeContinue() {
1612 ByteCodeField level = read();
1613 LDBG() << "Executing Continue\n * Level: " << level;
1614 ++loopIndex[level];
1615 popCodeIt();
1616}
1617
1618void ByteCodeExecutor::executeCreateConstantTypeRange() {
1619 LDBG() << "Executing CreateConstantTypeRange:";
1620 unsigned memIndex = read();
1621 unsigned rangeIndex = read();
1622 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1623
1624 LDBG() << " * Types: " << typesAttr;
1625 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1626 rangeIndex);
1627}
1628
1629void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1630 Location mainRewriteLoc) {
1631 LDBG() << "Executing CreateOperation:";
1632
1633 unsigned memIndex = read();
1634 OperationState state(mainRewriteLoc, read<OperationName>());
1635 readList(state.operands);
1636 for (unsigned i = 0, e = read(); i != e; ++i) {
1637 StringAttr name = read<StringAttr>();
1638 if (Attribute attr = read<Attribute>())
1639 state.addAttribute(name, attr);
1640 }
1641
1642 // Read in the result types. If the "size" is the sentinel value, this
1643 // indicates that the result types should be inferred.
1644 unsigned numResults = read();
1645 if (numResults == kInferTypesMarker) {
1646 InferTypeOpInterface::Concept *inferInterface =
1647 state.name.getInterface<InferTypeOpInterface>();
1648 assert(inferInterface &&
1649 "expected operation to provide InferTypeOpInterface");
1650
1651 // TODO: Handle failure.
1652 if (failed(inferInterface->inferReturnTypes(
1653 state.getContext(), state.location, state.operands,
1654 state.attributes.getDictionary(state.getContext()),
1655 state.getRawProperties(), state.regions, state.types)))
1656 return;
1657 } else {
1658 // Otherwise, this is a fixed number of results.
1659 for (unsigned i = 0; i != numResults; ++i) {
1660 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1661 state.types.push_back(read<Type>());
1662 } else {
1663 TypeRange *resultTypes = read<TypeRange *>();
1664 state.types.append(resultTypes->begin(), resultTypes->end());
1665 }
1666 }
1667 }
1668
1669 Operation *resultOp = rewriter.create(state);
1670 memory[memIndex] = resultOp;
1671
1672 LDBG() << " * Attributes: "
1673 << state.attributes.getDictionary(state.getContext())
1674 << "\n * Operands: " << llvm::interleaved(state.operands)
1675 << "\n * Result Types: " << llvm::interleaved(state.types)
1676 << "\n * Result: " << *resultOp;
1677}
1678
1679template <typename T>
1680void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1681 LDBG() << "Executing CreateDynamic" << type << "Range:";
1682 unsigned memIndex = read();
1683 unsigned rangeIndex = read();
1684 SmallVector<T> values;
1685 readList(values);
1686
1687 LDBG() << " * " << type << "s: " << llvm::interleaved(values);
1688
1689 assignRangeToMemory(values, memIndex, rangeIndex);
1690}
1691
1692void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1693 LDBG() << "Executing EraseOp:";
1694 Operation *op = read<Operation *>();
1695
1696 LDBG() << " * Operation: " << *op;
1697 rewriter.eraseOp(op);
1698}
1699
1700template <typename T, typename Range, PDLValue::Kind kind>
1701void ByteCodeExecutor::executeExtract() {
1702 LDBG() << "Executing Extract" << kind << ":";
1703 Range *range = read<Range *>();
1704 unsigned index = read<uint32_t>();
1705 unsigned memIndex = read();
1706
1707 if (!range) {
1708 memory[memIndex] = nullptr;
1709 return;
1710 }
1711
1712 T result = index < range->size() ? (*range)[index] : T();
1713 LDBG() << " * " << kind << "s(" << range->size() << ")";
1714 LDBG() << " * Index: " << index;
1715 LDBG() << " * Result: " << result;
1716 storeToMemory(memIndex, result);
1717}
1718
1719void ByteCodeExecutor::executeFinalize() { LDBG() << "Executing Finalize"; }
1720
1721void ByteCodeExecutor::executeForEach() {
1722 LDBG() << "Executing ForEach:";
1723 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1724 unsigned rangeIndex = read();
1725 unsigned memIndex = read();
1726 const void *value = nullptr;
1727
1728 switch (read<PDLValue::Kind>()) {
1729 case PDLValue::Kind::Operation: {
1730 unsigned &index = loopIndex[read()];
1731 ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1732 assert(index <= array.size() && "iterated past the end");
1733 if (index < array.size()) {
1734 LDBG() << " * Result: " << array[index];
1735 value = array[index];
1736 break;
1737 }
1738
1739 LDBG() << " * Done";
1740 index = 0;
1741 selectJump(size_t(0));
1742 return;
1743 }
1744 default:
1745 llvm_unreachable("unexpected `ForEach` value kind");
1746 }
1747
1748 // Store the iterate value and the stack address.
1749 memory[memIndex] = value;
1750 pushCodeIt(prevCodeIt);
1751
1752 // Skip over the successor (we will enter the body of the loop).
1753 read<ByteCodeAddr>();
1754}
1755
1756void ByteCodeExecutor::executeGetAttribute() {
1757 LDBG() << "Executing GetAttribute:";
1758 unsigned memIndex = read();
1759 Operation *op = read<Operation *>();
1760 StringAttr attrName = read<StringAttr>();
1761 Attribute attr = op->getAttr(attrName);
1762
1763 LDBG() << " * Operation: " << *op << "\n * Attribute: " << attrName
1764 << "\n * Result: " << attr;
1765 memory[memIndex] = attr.getAsOpaquePointer();
1766}
1767
1768void ByteCodeExecutor::executeGetAttributeType() {
1769 LDBG() << "Executing GetAttributeType:";
1770 unsigned memIndex = read();
1771 Attribute attr = read<Attribute>();
1772 Type type;
1773 if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1774 type = typedAttr.getType();
1775
1776 LDBG() << " * Attribute: " << attr << "\n * Result: " << type;
1777 memory[memIndex] = type.getAsOpaquePointer();
1778}
1779
1780void ByteCodeExecutor::executeGetDefiningOp() {
1781 LDBG() << "Executing GetDefiningOp:";
1782 unsigned memIndex = read();
1783 Operation *op = nullptr;
1784 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1785 Value value = read<Value>();
1786 if (value)
1787 op = value.getDefiningOp();
1788 LDBG() << " * Value: " << value;
1789 } else {
1790 ValueRange *values = read<ValueRange *>();
1791 if (values && !values->empty()) {
1792 op = values->front().getDefiningOp();
1793 }
1794 LDBG() << " * Values: " << values;
1795 }
1796
1797 LDBG() << " * Result: " << op;
1798 memory[memIndex] = op;
1799}
1800
1801void ByteCodeExecutor::executeGetOperand(unsigned index) {
1802 Operation *op = read<Operation *>();
1803 unsigned memIndex = read();
1804 Value operand =
1805 index < op->getNumOperands() ? op->getOperand(index) : Value();
1806
1807 LDBG() << " * Operation: " << *op << "\n * Index: " << index
1808 << "\n * Result: " << operand;
1809 memory[memIndex] = operand.getAsOpaquePointer();
1810}
1811
1812/// This function is the internal implementation of `GetResults` and
1813/// `GetOperands` that provides support for extracting a value range from the
1814/// given operation.
1815template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1816static void *
1817executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1818 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1819 MutableArrayRef<ValueRange> valueRangeMemory) {
1820 // Check for the sentinel index that signals that all values should be
1821 // returned.
1822 if (index == std::numeric_limits<uint32_t>::max()) {
1823 LDBG() << " * Getting all values";
1824 // `values` is already the full value range.
1825
1826 // Otherwise, check to see if this operation uses AttrSizedSegments.
1827 } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1828 LDBG() << " * Extracting values from `" << attrSizedSegments << "`";
1829
1830 auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1831 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1832 return nullptr;
1833
1834 ArrayRef<int32_t> segments = segmentAttr;
1835 unsigned startIndex = llvm::sum_of(segments.take_front(index));
1836 values = values.slice(startIndex, *std::next(segments.begin(), index));
1837
1838 LDBG() << " * Extracting range[" << startIndex << ", "
1839 << *std::next(segments.begin(), index) << "]";
1840
1841 // Otherwise, assume this is the last operand group of the operation.
1842 // FIXME: We currently don't support operations with
1843 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1844 // have a way to detect it's presence.
1845 } else if (values.size() >= index) {
1846 LDBG() << " * Treating values as trailing variadic range";
1847 values = values.drop_front(index);
1848
1849 // If we couldn't detect a way to compute the values, bail out.
1850 } else {
1851 return nullptr;
1852 }
1853
1854 // If the range index is valid, we are returning a range.
1855 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1856 valueRangeMemory[rangeIndex] = values;
1857 return &valueRangeMemory[rangeIndex];
1858 }
1859
1860 // If a range index wasn't provided, the range is required to be non-variadic.
1861 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1862}
1863
1864void ByteCodeExecutor::executeGetOperands() {
1865 LDBG() << "Executing GetOperands:";
1866 unsigned index = read<uint32_t>();
1867 Operation *op = read<Operation *>();
1868 ByteCodeField rangeIndex = read();
1869
1871 op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
1872 valueRangeMemory);
1873 if (!result)
1874 LDBG() << " * Invalid operand range";
1875 memory[read()] = result;
1876}
1877
1878void ByteCodeExecutor::executeGetResult(unsigned index) {
1879 Operation *op = read<Operation *>();
1880 unsigned memIndex = read();
1881 OpResult result =
1882 index < op->getNumResults() ? op->getResult(index) : OpResult();
1883
1884 LDBG() << " * Operation: " << *op << "\n * Index: " << index
1885 << "\n * Result: " << result;
1886 memory[memIndex] = result.getAsOpaquePointer();
1887}
1888
1889void ByteCodeExecutor::executeGetResults() {
1890 LDBG() << "Executing GetResults:";
1891 unsigned index = read<uint32_t>();
1892 Operation *op = read<Operation *>();
1893 ByteCodeField rangeIndex = read();
1894
1896 op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
1897 valueRangeMemory);
1898 if (!result)
1899 LDBG() << " * Invalid result range";
1900 memory[read()] = result;
1901}
1902
1903void ByteCodeExecutor::executeGetUsers() {
1904 LDBG() << "Executing GetUsers:";
1905 unsigned memIndex = read();
1906 unsigned rangeIndex = read();
1907 std::vector<Operation *> &range = opRangeMemory[rangeIndex];
1908 memory[memIndex] = &range;
1909
1910 range.clear();
1911 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1912 // Read the value.
1913 Value value = read<Value>();
1914 if (!value)
1915 return;
1916 LDBG() << " * Value: " << value;
1917
1918 range.assign(value.user_begin(), value.user_end());
1919 } else {
1920 // Read a range of values.
1921 ValueRange *values = read<ValueRange *>();
1922 if (!values)
1923 return;
1924 LDBG() << " * Values (" << values->size()
1925 << "): " << llvm::interleaved(*values);
1926
1927 for (Value value : *values)
1928 range.insert(range.end(), value.user_begin(), value.user_end());
1929 }
1930
1931 LDBG() << " * Result: " << range.size() << " operations";
1932}
1933
1934void ByteCodeExecutor::executeGetValueType() {
1935 LDBG() << "Executing GetValueType:";
1936 unsigned memIndex = read();
1937 Value value = read<Value>();
1938 Type type = value ? value.getType() : Type();
1939
1940 LDBG() << " * Value: " << value << "\n * Result: " << type;
1941 memory[memIndex] = type.getAsOpaquePointer();
1942}
1943
1944void ByteCodeExecutor::executeGetValueRangeTypes() {
1945 LDBG() << "Executing GetValueRangeTypes:";
1946 unsigned memIndex = read();
1947 unsigned rangeIndex = read();
1948 ValueRange *values = read<ValueRange *>();
1949 if (!values) {
1950 LDBG() << " * Values: <NULL>";
1951 memory[memIndex] = nullptr;
1952 return;
1953 }
1954
1955 LDBG() << " * Values (" << values->size()
1956 << "): " << llvm::interleaved(*values)
1957 << "\n * Result: " << llvm::interleaved(values->getType());
1958 typeRangeMemory[rangeIndex] = values->getType();
1959 memory[memIndex] = &typeRangeMemory[rangeIndex];
1960}
1961
1962void ByteCodeExecutor::executeIsNotNull() {
1963 LDBG() << "Executing IsNotNull:";
1964 const void *value = read<const void *>();
1965
1966 LDBG() << " * Value: " << value;
1967 selectJump(value != nullptr);
1968}
1969
1970void ByteCodeExecutor::executeRecordMatch(
1971 PatternRewriter &rewriter,
1972 SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1973 LDBG() << "Executing RecordMatch:";
1974 unsigned patternIndex = read();
1975 PatternBenefit benefit = currentPatternBenefits[patternIndex];
1976 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1977
1978 // If the benefit of the pattern is impossible, skip the processing of the
1979 // rest of the pattern.
1980 if (benefit.isImpossibleToMatch()) {
1981 LDBG() << " * Benefit: Impossible To Match";
1982 curCodeIt = dest;
1983 return;
1984 }
1985
1986 // Create a fused location containing the locations of each of the
1987 // operations used in the match. This will be used as the location for
1988 // created operations during the rewrite that don't already have an
1989 // explicit location set.
1990 unsigned numMatchLocs = read();
1991 SmallVector<Location, 4> matchLocs;
1992 matchLocs.reserve(numMatchLocs);
1993 for (unsigned i = 0; i != numMatchLocs; ++i)
1994 matchLocs.push_back(read<Operation *>()->getLoc());
1995 Location matchLoc = rewriter.getFusedLoc(matchLocs);
1996
1997 LDBG() << " * Benefit: " << benefit.getBenefit();
1998 LDBG() << " * Location: " << matchLoc;
1999 matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2000 PDLByteCode::MatchResult &match = matches.back();
2001
2002 // Record all of the inputs to the match. If any of the inputs are ranges, we
2003 // will also need to remap the range pointer to memory stored in the match
2004 // state.
2005 unsigned numInputs = read();
2006 match.values.reserve(numInputs);
2007 match.typeRangeValues.reserve(numInputs);
2008 match.valueRangeValues.reserve(numInputs);
2009 for (unsigned i = 0; i < numInputs; ++i) {
2010 switch (read<PDLValue::Kind>()) {
2011 case PDLValue::Kind::TypeRange:
2012 match.typeRangeValues.push_back(*read<TypeRange *>());
2013 match.values.push_back(&match.typeRangeValues.back());
2014 break;
2015 case PDLValue::Kind::ValueRange:
2016 match.valueRangeValues.push_back(*read<ValueRange *>());
2017 match.values.push_back(&match.valueRangeValues.back());
2018 break;
2019 default:
2020 match.values.push_back(read<const void *>());
2021 break;
2022 }
2023 }
2024 curCodeIt = dest;
2025}
2026
2027void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2028 LDBG() << "Executing ReplaceOp:";
2029 Operation *op = read<Operation *>();
2030 SmallVector<Value, 16> args;
2031 readList(args);
2032
2033 LDBG() << " * Operation: " << *op
2034 << "\n * Values: " << llvm::interleaved(args);
2035 rewriter.replaceOp(op, args);
2036}
2037
2038void ByteCodeExecutor::executeSwitchAttribute() {
2039 LDBG() << "Executing SwitchAttribute:";
2040 Attribute value = read<Attribute>();
2041 ArrayAttr cases = read<ArrayAttr>();
2042 handleSwitch(value, cases);
2043}
2044
2045void ByteCodeExecutor::executeSwitchOperandCount() {
2046 LDBG() << "Executing SwitchOperandCount:";
2047 Operation *op = read<Operation *>();
2048 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2049
2050 LDBG() << " * Operation: " << *op;
2051 handleSwitch(op->getNumOperands(), cases);
2052}
2053
2054void ByteCodeExecutor::executeSwitchOperationName() {
2055 LDBG() << "Executing SwitchOperationName:";
2056 OperationName value = read<Operation *>()->getName();
2057 size_t caseCount = read();
2058
2059 // The operation names are stored in-line, so to print them out for
2060 // debugging purposes we need to read the array before executing the
2061 // switch so that we can display all of the possible values.
2062 LLVM_DEBUG({
2063 const ByteCodeField *prevCodeIt = curCodeIt;
2064 LDBG() << " * Value: " << value << "\n * Cases: "
2065 << llvm::interleaved(
2066 llvm::map_range(llvm::seq<size_t>(0, caseCount), [&](size_t) {
2067 return read<OperationName>();
2068 }));
2069 curCodeIt = prevCodeIt;
2070 });
2071
2072 // Try to find the switch value within any of the cases.
2073 for (size_t i = 0; i != caseCount; ++i) {
2074 if (read<OperationName>() == value) {
2075 curCodeIt += (caseCount - i - 1);
2076 return selectJump(i + 1);
2077 }
2078 }
2079 selectJump(size_t(0));
2080}
2081
2082void ByteCodeExecutor::executeSwitchResultCount() {
2083 LDBG() << "Executing SwitchResultCount:";
2084 Operation *op = read<Operation *>();
2085 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2086
2087 LDBG() << " * Operation: " << *op;
2088 handleSwitch(op->getNumResults(), cases);
2089}
2090
2091void ByteCodeExecutor::executeSwitchType() {
2092 LDBG() << "Executing SwitchType:";
2093 Type value = read<Type>();
2094 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2095 handleSwitch(value, cases);
2096}
2097
2098void ByteCodeExecutor::executeSwitchTypes() {
2099 LDBG() << "Executing SwitchTypes:";
2100 TypeRange *value = read<TypeRange *>();
2101 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2102 if (!value) {
2103 LDBG() << "Types: <NULL>";
2104 return selectJump(size_t(0));
2105 }
2106 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2107 return value == caseValue.getAsValueRange<TypeAttr>();
2108 });
2109}
2110
2111LogicalResult
2112ByteCodeExecutor::execute(PatternRewriter &rewriter,
2113 SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2114 std::optional<Location> mainRewriteLoc) {
2115 while (true) {
2116 // Print the location of the operation being executed.
2117 LDBG() << readInline<Location>();
2118
2119 OpCode opCode = static_cast<OpCode>(read());
2120 switch (opCode) {
2121 case ApplyConstraint:
2122 executeApplyConstraint(rewriter);
2123 break;
2124 case ApplyRewrite:
2125 if (failed(executeApplyRewrite(rewriter)))
2126 return failure();
2127 break;
2128 case AreEqual:
2129 executeAreEqual();
2130 break;
2131 case AreRangesEqual:
2132 executeAreRangesEqual();
2133 break;
2134 case Branch:
2135 executeBranch();
2136 break;
2137 case CheckOperandCount:
2138 executeCheckOperandCount();
2139 break;
2140 case CheckOperationName:
2141 executeCheckOperationName();
2142 break;
2143 case CheckResultCount:
2144 executeCheckResultCount();
2145 break;
2146 case CheckTypes:
2147 executeCheckTypes();
2148 break;
2149 case Continue:
2150 executeContinue();
2151 break;
2152 case CreateConstantTypeRange:
2153 executeCreateConstantTypeRange();
2154 break;
2155 case CreateOperation:
2156 executeCreateOperation(rewriter, *mainRewriteLoc);
2157 break;
2158 case CreateDynamicTypeRange:
2159 executeDynamicCreateRange<Type>("Type");
2160 break;
2161 case CreateDynamicValueRange:
2162 executeDynamicCreateRange<Value>("Value");
2163 break;
2164 case EraseOp:
2165 executeEraseOp(rewriter);
2166 break;
2167 case ExtractOp:
2168 executeExtract<Operation *, std::vector<Operation *>,
2169 PDLValue::Kind::Operation>();
2170 break;
2171 case ExtractType:
2172 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2173 break;
2174 case ExtractValue:
2175 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2176 break;
2177 case Finalize:
2178 executeFinalize();
2179 LDBG() << "";
2180 return success();
2181 case ForEach:
2182 executeForEach();
2183 break;
2184 case GetAttribute:
2185 executeGetAttribute();
2186 break;
2187 case GetAttributeType:
2188 executeGetAttributeType();
2189 break;
2190 case GetDefiningOp:
2191 executeGetDefiningOp();
2192 break;
2193 case GetOperand0:
2194 case GetOperand1:
2195 case GetOperand2:
2196 case GetOperand3: {
2197 unsigned index = opCode - GetOperand0;
2198 LDBG() << "Executing GetOperand" << index << ":";
2199 executeGetOperand(index);
2200 break;
2201 }
2202 case GetOperandN:
2203 LDBG() << "Executing GetOperandN:";
2204 executeGetOperand(read<uint32_t>());
2205 break;
2206 case GetOperands:
2207 executeGetOperands();
2208 break;
2209 case GetResult0:
2210 case GetResult1:
2211 case GetResult2:
2212 case GetResult3: {
2213 unsigned index = opCode - GetResult0;
2214 LDBG() << "Executing GetResult" << index << ":";
2215 executeGetResult(index);
2216 break;
2217 }
2218 case GetResultN:
2219 LDBG() << "Executing GetResultN:";
2220 executeGetResult(read<uint32_t>());
2221 break;
2222 case GetResults:
2223 executeGetResults();
2224 break;
2225 case GetUsers:
2226 executeGetUsers();
2227 break;
2228 case GetValueType:
2229 executeGetValueType();
2230 break;
2231 case GetValueRangeTypes:
2232 executeGetValueRangeTypes();
2233 break;
2234 case IsNotNull:
2235 executeIsNotNull();
2236 break;
2237 case RecordMatch:
2238 assert(matches &&
2239 "expected matches to be provided when executing the matcher");
2240 executeRecordMatch(rewriter, *matches);
2241 break;
2242 case ReplaceOp:
2243 executeReplaceOp(rewriter);
2244 break;
2245 case SwitchAttribute:
2246 executeSwitchAttribute();
2247 break;
2248 case SwitchOperandCount:
2249 executeSwitchOperandCount();
2250 break;
2251 case SwitchOperationName:
2252 executeSwitchOperationName();
2253 break;
2254 case SwitchResultCount:
2255 executeSwitchResultCount();
2256 break;
2257 case SwitchType:
2258 executeSwitchType();
2259 break;
2260 case SwitchTypes:
2261 executeSwitchTypes();
2262 break;
2263 }
2264 LDBG() << "";
2265 }
2266}
2267
2268void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2269 SmallVectorImpl<MatchResult> &matches,
2270 PDLByteCodeMutableState &state) const {
2271 // The first memory slot is always the root operation.
2272 state.memory[0] = op;
2273
2274 // The matcher function always starts at code address 0.
2275 ByteCodeExecutor executor(
2276 matcherByteCode.data(), state.memory, state.opRangeMemory,
2277 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2278 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2279 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2280 constraintFunctions, rewriteFunctions);
2281 LogicalResult executeResult = executor.execute(rewriter, &matches);
2282 (void)executeResult;
2283 assert(succeeded(executeResult) && "unexpected matcher execution failure");
2284
2285 // Order the found matches by benefit.
2286 llvm::stable_sort(matches,
2287 [](const MatchResult &lhs, const MatchResult &rhs) {
2288 return lhs.benefit > rhs.benefit;
2289 });
2290}
2291
2292LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2293 const MatchResult &match,
2294 PDLByteCodeMutableState &state) const {
2295 auto *configSet = match.pattern->getConfigSet();
2296 if (configSet)
2297 configSet->notifyRewriteBegin(rewriter);
2298
2299 // The arguments of the rewrite function are stored at the start of the
2300 // memory buffer.
2301 llvm::copy(match.values, state.memory.begin());
2302
2303 ByteCodeExecutor executor(
2304 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2305 state.opRangeMemory, state.typeRangeMemory,
2306 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2307 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2308 rewriterByteCode, state.currentPatternBenefits, patterns,
2309 constraintFunctions, rewriteFunctions);
2310 LogicalResult result =
2311 executor.execute(rewriter, /*matches=*/nullptr, match.location);
2312
2313 if (configSet)
2314 configSet->notifyRewriteEnd(rewriter);
2315
2316 // If the rewrite failed, check if the pattern rewriter can recover. If it
2317 // can, we can signal to the pattern applicator to keep trying patterns. If it
2318 // doesn't, we need to bail. Bailing here should be fine, given that we have
2319 // no means to propagate such a failure to the user, and it also indicates a
2320 // bug in the user code (i.e. failable rewrites should not be used with
2321 // pattern rewriters that don't support it).
2322 if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2323 LDBG() << " and rollback is not supported - aborting";
2324 llvm::report_fatal_error(
2325 "Native PDL Rewrite failed, but the pattern "
2326 "rewriter doesn't support recovery. Failable pattern rewrites should "
2327 "not be used with pattern rewriters that do not support them.");
2328 }
2329 return result;
2330}
for(Operation *op :ops)
return success()
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition ByteCode.cpp:175
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for extracting a value range from the given operation.
lhs
ArrayAttr()
static const mlir::GenInfo * generator
static void processValue(Value value, LiveMap &liveMap)
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition Attributes.h:73
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:153
BlockArgListType getArguments()
Definition Block.h:87
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition Builders.cpp:27
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition Liveness.h:110
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition Liveness.cpp:375
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Types.h:164
type_range getType() const
Type getType() const
Return the type of this value.
Definition Value.h:105
user_iterator user_begin() const
Definition Value.h:216
user_iterator user_end() const
Definition Value.h:217
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Value.h:233
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition ByteCode.h:235
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
Definition ByteCode.h:234
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition ByteCode.h:248
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
Definition ByteCode.h:247
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition ByteCode.h:251
AttrTypeReplacer.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
OpFoldResult size