MLIR 22.0.0git
ByteCode.cpp
Go to the documentation of this file.
1//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements MLIR to byte-code generation and the interpreter.
10//
11//===----------------------------------------------------------------------===//
12
13#include "ByteCode.h"
17#include "mlir/IR/BuiltinOps.h"
19#include "llvm/ADT/IntervalMap.h"
20#include "llvm/ADT/PostOrderIterator.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/Format.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/InterleavedRange.h"
27#include <numeric>
28#include <optional>
29
30#define DEBUG_TYPE "pdl-bytecode"
31
32using namespace mlir;
33using namespace mlir::detail;
34
35//===----------------------------------------------------------------------===//
36// PDLByteCodePattern
37//===----------------------------------------------------------------------===//
38
39PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
40 PDLPatternConfigSet *configSet,
41 ByteCodeAddr rewriterAddr) {
42 PatternBenefit benefit = matchOp.getBenefit();
43 MLIRContext *ctx = matchOp.getContext();
44
45 // Collect the set of generated operations.
46 SmallVector<StringRef, 8> generatedOps;
47 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
48 generatedOps =
49 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
50
51 // Check to see if this is pattern matches a specific operation type.
52 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
53 return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
54 generatedOps);
55 return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
56 benefit, ctx, generatedOps);
57}
58
59//===----------------------------------------------------------------------===//
60// PDLByteCodeMutableState
61//===----------------------------------------------------------------------===//
62
63/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
64/// to the position of the pattern within the range returned by
65/// `PDLByteCode::getPatterns`.
66void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
67 PatternBenefit benefit) {
68 currentPatternBenefits[patternIndex] = benefit;
69}
70
/// Clean up any allocated state after a full match/rewrite has been completed.
/// This method should be called regardless of whether the match+rewrite was a
/// success or not.
75 allocatedTypeRangeMemory.clear();
76 allocatedValueRangeMemory.clear();
77}
78
79//===----------------------------------------------------------------------===//
80// Bytecode OpCodes
81//===----------------------------------------------------------------------===//
82
namespace {
/// Opcodes of the PDL bytecode. `ByteCodeWriter::append(OpCode)` writes an
/// opcode as a single `ByteCodeField` holding its numeric value. Enumerators
/// receive implicit values in declaration order, so inserting, removing, or
/// reordering entries changes the numbers written into the bytecode.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal. The generator also emits this
  /// opcode for `pdl_interp.check_attribute`, comparing against a constant.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create a type range from a list of constant types.
  CreateConstantTypeRange,
  /// Create an operation.
  CreateOperation,
  /// Create a type range from a list of dynamic types.
  CreateDynamicTypeRange,
  /// Create a value range.
  CreateDynamicValueRange,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation. `GetOperand0`-`GetOperand3` are
  /// dedicated opcodes for the first four indices; `GetOperandN` presumably
  /// carries the index as an explicit field (verify against the generator).
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation. Same 0-3/N split as the operand
  /// opcodes above.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
173
174/// A marker used to indicate if an operation should infer types.
175static constexpr ByteCodeField kInferTypesMarker =
176 std::numeric_limits<ByteCodeField>::max();
177
178//===----------------------------------------------------------------------===//
179// ByteCode Generation
180//===----------------------------------------------------------------------===//
181
182//===----------------------------------------------------------------------===//
183// Generator
184//===----------------------------------------------------------------------===//
185
186namespace {
187struct ByteCodeLiveRange;
188struct ByteCodeWriter;
189
/// Check if the given class `T` can be converted to an opaque pointer.
///
/// This is a detection alias meant for `llvm::is_detected`. It names the type
/// of `T().getAsOpaquePointer()` and is ill-formed when `T` has no such
/// member, which makes `is_detected<has_pointer_traits, T>` false. The
/// trailing `Args` pack is not used by the alias itself.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
193
194/// This class represents the main generator for the pattern bytecode.
195class Generator {
196public:
197 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
198 SmallVectorImpl<ByteCodeField> &matcherByteCode,
199 SmallVectorImpl<ByteCodeField> &rewriterByteCode,
201 ByteCodeField &maxValueMemoryIndex,
202 ByteCodeField &maxOpRangeMemoryIndex,
203 ByteCodeField &maxTypeRangeMemoryIndex,
204 ByteCodeField &maxValueRangeMemoryIndex,
205 ByteCodeField &maxLoopLevel,
206 llvm::StringMap<PDLConstraintFunction> &constraintFns,
207 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
209 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
210 rewriterByteCode(rewriterByteCode), patterns(patterns),
211 maxValueMemoryIndex(maxValueMemoryIndex),
212 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
213 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
214 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
215 maxLoopLevel(maxLoopLevel), configMap(configMap) {
216 for (const auto &it : llvm::enumerate(constraintFns))
217 constraintToMemIndex.try_emplace(it.value().first(), it.index());
218 for (const auto &it : llvm::enumerate(rewriteFns))
219 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
220 }
221
222 /// Generate the bytecode for the given PDL interpreter module.
223 void generate(ModuleOp module);
224
225 /// Return the memory index to use for the given value.
226 ByteCodeField &getMemIndex(Value value) {
227 assert(valueToMemIndex.count(value) &&
228 "expected memory index to be assigned");
229 return valueToMemIndex[value];
230 }
231
232 /// Return the range memory index used to store the given range value.
233 ByteCodeField &getRangeStorageIndex(Value value) {
234 assert(valueToRangeIndex.count(value) &&
235 "expected range index to be assigned");
236 return valueToRangeIndex[value];
237 }
238
239 /// Return an index to use when referring to the given data that is uniqued in
240 /// the MLIR context.
241 template <typename T>
242 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
243 getMemIndex(T val) {
244 const void *opaqueVal = val.getAsOpaquePointer();
245
246 // Get or insert a reference to this value.
247 auto it = uniquedDataToMemIndex.try_emplace(
248 opaqueVal, maxValueMemoryIndex + uniquedData.size());
249 if (it.second)
250 uniquedData.push_back(opaqueVal);
251 return it.first->second;
252 }
253
254private:
255 /// Allocate memory indices for the results of operations within the matcher
256 /// and rewriters.
257 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
258 ModuleOp rewriterModule);
259
260 /// Generate the bytecode for the given operation.
261 void generate(Region *region, ByteCodeWriter &writer);
262 void generate(Operation *op, ByteCodeWriter &writer);
263 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
298 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
299 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
300 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
301
302 /// Mapping from value to its corresponding memory index.
303 DenseMap<Value, ByteCodeField> valueToMemIndex;
304
305 /// Mapping from a range value to its corresponding range storage index.
306 DenseMap<Value, ByteCodeField> valueToRangeIndex;
307
308 /// Mapping from the name of an externally registered rewrite to its index in
309 /// the bytecode registry.
310 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
311
312 /// Mapping from the name of an externally registered constraint to its index
313 /// in the bytecode registry.
314 llvm::StringMap<ByteCodeField> constraintToMemIndex;
315
316 /// Mapping from rewriter function name to the bytecode address of the
317 /// rewriter function in byte.
318 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
319
320 /// Mapping from a uniqued storage object to its memory index within
321 /// `uniquedData`.
322 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
323
324 /// The current level of the foreach loop.
325 ByteCodeField curLoopLevel = 0;
326
327 /// The current MLIR context.
328 MLIRContext *ctx;
329
330 /// Mapping from block to its address.
332
333 /// Data of the ByteCode class to be populated.
334 std::vector<const void *> &uniquedData;
335 SmallVectorImpl<ByteCodeField> &matcherByteCode;
336 SmallVectorImpl<ByteCodeField> &rewriterByteCode;
337 SmallVectorImpl<PDLByteCodePattern> &patterns;
338 ByteCodeField &maxValueMemoryIndex;
339 ByteCodeField &maxOpRangeMemoryIndex;
340 ByteCodeField &maxTypeRangeMemoryIndex;
341 ByteCodeField &maxValueRangeMemoryIndex;
342 ByteCodeField &maxLoopLevel;
343
344 /// A map of pattern configurations.
346};
347
348/// This class provides utilities for writing a bytecode stream.
349struct ByteCodeWriter {
  /// Create a writer that appends to `bytecode`. `generator` is consulted to
  /// map values and uniqued data to their memory indices.
  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
      : bytecode(bytecode), generator(generator) {}

  /// Append a field to the bytecode. An opcode also occupies a single field.
  void append(ByteCodeField field) { bytecode.push_back(field); }
  void append(OpCode opCode) { bytecode.push_back(opCode); }

  /// Append an address to the bytecode. The address is copied bytewise into
  /// exactly two fields (enforced by the static_assert), so the order of the
  /// two halves follows the host's byte order.
  void append(ByteCodeAddr field) {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");

    ByteCodeField fieldParts[2];
    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
    bytecode.append({fieldParts[0], fieldParts[1]});
  }

  /// Append a single successor to the bytecode, the exact address will need to
  /// be resolved later.
  void append(Block *successor) {
    // Record the offset of the placeholder so that the address can be resolved
    // later, then write a zero address to be patched once it is known.
    unresolvedSuccessorRefs[successor].push_back(bytecode.size());
    append(ByteCodeAddr(0));
  }

  /// Append a successor range to the bytecode, the exact address will need to
  /// be resolved later. Only one placeholder address per successor is written;
  /// no count precedes them.
  void append(SuccessorRange successors) {
    for (Block *successor : successors)
      append(successor);
  }

  /// Append a range of values that will be read as generic PDLValues: the
  /// number of values, followed by the kind and memory index of each one.
  void appendPDLValueList(OperandRange values) {
    bytecode.push_back(values.size());
    for (Value value : values)
      appendPDLValue(value);
  }

  /// Append a value as a PDLValue: its kind followed by its memory index.
  void appendPDLValue(Value value) {
    appendPDLValueKind(value);
    append(value);
  }

  /// Append the PDLValue::Kind of the given value, derived from its type.
  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
398
399 /// Append the PDLValue::Kind of the given type.
400 void appendPDLValueKind(Type type) {
401 PDLValue::Kind kind =
403 .Case<pdl::AttributeType>(
404 [](Type) { return PDLValue::Kind::Attribute; })
405 .Case<pdl::OperationType>(
406 [](Type) { return PDLValue::Kind::Operation; })
407 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
408 if (isa<pdl::TypeType>(rangeTy.getElementType()))
409 return PDLValue::Kind::TypeRange;
410 return PDLValue::Kind::ValueRange;
411 })
412 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
413 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
414 bytecode.push_back(static_cast<ByteCodeField>(kind));
415 }
416
  /// Append a value that will be stored in a memory slot and not inline within
  /// the bytecode. Only the slot index from `Generator::getMemIndex` is
  /// written. Enabled for types with `getAsOpaquePointer()` and for raw
  /// pointer types.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
                   std::is_pointer<T>::value>
  append(T value) {
    bytecode.push_back(generator.getMemIndex(value));
  }

  /// Append a range of values: the element count, followed by each element
  /// appended through the matching single-element overload. Disabled for types
  /// with `getAsOpaquePointer()`. The defaulted `IteratorT` parameter
  /// presumably exists to drop non-range types from overload resolution.
  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
  append(T range) {
    bytecode.push_back(llvm::size(range));
    for (auto it : range)
      append(it);
  }

  /// Append a variadic number of fields to the bytecode, in argument order,
  /// each through its own single-argument overload.
  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
    append(field);
    append(field2, fields...);
  }

  /// Appends a value as a pointer, stored inline within the bytecode. The
  /// opaque pointer is copied bytewise across
  /// `sizeof(const void *) / sizeof(ByteCodeField)` fields.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  appendInline(T value) {
    constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
    const void *pointer = value.getAsOpaquePointer();
    ByteCodeField fieldParts[numParts];
    std::memcpy(fieldParts, &pointer, sizeof(const void *));
    bytecode.append(fieldParts, fieldParts + numParts);
  }

  /// Successor references in the bytecode that have yet to be resolved, keyed
  /// by successor block. Each entry lists the bytecode offsets of placeholder
  /// addresses written by `append(Block *)`.
  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;

  /// The underlying bytecode buffer.
  SmallVectorImpl<ByteCodeField> &bytecode;

  /// The main generator producing PDL.
  Generator &generator;
461};
462
463/// This class represents a live range of PDL Interpreter values, containing
464/// information about when values are live within a match/rewrite.
465struct ByteCodeLiveRange {
466 using Set = llvm::IntervalMap<uint64_t, char, 16>;
467 using Allocator = Set::Allocator;
468
469 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
470
471 /// Union this live range with the one provided.
472 void unionWith(const ByteCodeLiveRange &rhs) {
473 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
474 ++it)
475 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
476 }
477
478 /// Returns true if this range overlaps with the one provided.
479 bool overlaps(const ByteCodeLiveRange &rhs) const {
480 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
481 .valid();
482 }
483
484 /// A map representing the ranges of the match/rewrite that a value is live in
485 /// the interpreter.
486 ///
487 /// We use std::unique_ptr here, because IntervalMap does not provide a
488 /// correct copy or move constructor. We can eliminate the pointer once
489 /// https://reviews.llvm.org/D113240 lands.
490 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
491
492 /// The operation range storage index for this range.
493 std::optional<unsigned> opRangeIndex;
494
495 /// The type range storage index for this range.
496 std::optional<unsigned> typeRangeIndex;
497
498 /// The value range storage index for this range.
499 std::optional<unsigned> valueRangeIndex;
500};
501} // namespace
502
503void Generator::generate(ModuleOp module) {
504 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
505 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
506 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
507 pdl_interp::PDLInterpDialect::getRewriterModuleName());
508 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
509
510 // Allocate memory indices for the results of operations within the matcher
511 // and rewriters.
512 allocateMemoryIndices(matcherFunc, rewriterModule);
513
514 // Generate code for the rewriter functions.
515 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
516 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
517 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
518 for (Operation &op : rewriterFunc.getOps())
519 generate(&op, rewriterByteCodeWriter);
520 }
521 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
522 "unexpected branches in rewriter function");
523
524 // Generate code for the matcher function.
525 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
526 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
527
528 // Resolve successor references in the matcher.
529 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
530 ByteCodeAddr addr = blockToAddr[it.first];
531 for (unsigned offsetToFix : it.second)
532 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
533 }
534}
535
536void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
537 ModuleOp rewriterModule) {
538 // Rewriters use simplistic allocation scheme that simply assigns an index to
539 // each result.
540 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
541 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
542 auto processRewriterValue = [&](Value val) {
543 valueToMemIndex.try_emplace(val, index++);
544 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
545 Type elementTy = rangeType.getElementType();
546 if (isa<pdl::TypeType>(elementTy))
547 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
548 else if (isa<pdl::ValueType>(elementTy))
549 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
550 }
551 };
552
553 for (BlockArgument arg : rewriterFunc.getArguments())
554 processRewriterValue(arg);
555 rewriterFunc.getBody().walk([&](Operation *op) {
556 for (Value result : op->getResults())
557 processRewriterValue(result);
558 });
559 if (index > maxValueMemoryIndex)
560 maxValueMemoryIndex = index;
561 if (typeRangeIndex > maxTypeRangeMemoryIndex)
562 maxTypeRangeMemoryIndex = typeRangeIndex;
563 if (valueRangeIndex > maxValueRangeMemoryIndex)
564 maxValueRangeMemoryIndex = valueRangeIndex;
565 }
566
567 // The matcher function uses a more sophisticated numbering that tries to
568 // minimize the number of memory indices assigned. This is done by determining
569 // a live range of the values within the matcher, then the allocation is just
570 // finding the minimal number of overlapping live ranges. This is essentially
571 // a simplified form of register allocation where we don't necessarily have a
572 // limited number of registers, but we still want to minimize the number used.
573 DenseMap<Operation *, unsigned> opToFirstIndex;
575
576 // A custom walk that marks the first and the last index of each operation.
577 // The entry marks the beginning of the liveness range for this operation,
578 // followed by nested operations, followed by the end of the liveness range.
579 unsigned index = 0;
580 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
581 opToFirstIndex.try_emplace(op, index++);
582 for (Region &region : op->getRegions())
583 for (Block &block : region.getBlocks())
584 for (Operation &nested : block)
585 walk(&nested);
586 opToLastIndex.try_emplace(op, index++);
587 };
588 walk(matcherFunc);
589
590 // Liveness info for each of the defs within the matcher.
591 ByteCodeLiveRange::Allocator allocator;
593
594 // Assign the root operation being matched to slot 0.
595 BlockArgument rootOpArg = matcherFunc.getArgument(0);
596 valueToMemIndex[rootOpArg] = 0;
597
598 // Walk each of the blocks, computing the def interval that the value is used.
599 Liveness matcherLiveness(matcherFunc);
600 matcherFunc->walk([&](Block *block) {
601 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
602 assert(info && "expected liveness info for block");
603 auto processValue = [&](Value value, Operation *firstUseOrDef) {
604 // We don't need to process the root op argument, this value is always
605 // assigned to the first memory slot.
606 if (value == rootOpArg)
607 return;
608
609 // Set indices for the range of this block that the value is used.
610 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
611 defRangeIt->second.liveness->insert(
612 opToFirstIndex[firstUseOrDef],
613 opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
614 /*dummyValue*/ 0);
615
616 // Check to see if this value is a range type.
617 if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
618 Type eleType = rangeTy.getElementType();
619 if (isa<pdl::OperationType>(eleType))
620 defRangeIt->second.opRangeIndex = 0;
621 else if (isa<pdl::TypeType>(eleType))
622 defRangeIt->second.typeRangeIndex = 0;
623 else if (isa<pdl::ValueType>(eleType))
624 defRangeIt->second.valueRangeIndex = 0;
625 }
626 };
627
628 // Process the live-ins of this block.
629 for (Value liveIn : info->in()) {
630 // Only process the value if it has been defined in the current region.
631 // Other values that span across pdl_interp.foreach will be added higher
632 // up. This ensures that the we keep them alive for the entire duration
633 // of the loop.
634 if (liveIn.getParentRegion() == block->getParent())
635 processValue(liveIn, &block->front());
636 }
637
638 // Process the block arguments for the entry block (those are not live-in).
639 if (block->isEntryBlock()) {
640 for (Value argument : block->getArguments())
641 processValue(argument, &block->front());
642 }
643
644 // Process any new defs within this block.
645 for (Operation &op : *block)
646 for (Value result : op.getResults())
647 processValue(result, &op);
648 });
649
650 // Greedily allocate memory slots using the computed def live ranges.
651 std::vector<ByteCodeLiveRange> allocatedIndices;
652
653 // The number of memory indices currently allocated (and its next value).
654 // Recall that the root gets allocated memory index 0.
655 ByteCodeField numIndices = 1;
656
657 // The number of memory ranges of various types (and their next values).
658 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
659
660 for (auto &defIt : valueDefRanges) {
661 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
662 ByteCodeLiveRange &defRange = defIt.second;
663
664 // Try to allocate to an existing index.
665 for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
666 ByteCodeLiveRange &existingRange = existingIndexIt.value();
667 if (!defRange.overlaps(existingRange)) {
668 existingRange.unionWith(defRange);
669 memIndex = existingIndexIt.index() + 1;
670
671 if (defRange.opRangeIndex) {
672 if (!existingRange.opRangeIndex)
673 existingRange.opRangeIndex = numOpRanges++;
674 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
675 } else if (defRange.typeRangeIndex) {
676 if (!existingRange.typeRangeIndex)
677 existingRange.typeRangeIndex = numTypeRanges++;
678 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
679 } else if (defRange.valueRangeIndex) {
680 if (!existingRange.valueRangeIndex)
681 existingRange.valueRangeIndex = numValueRanges++;
682 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
683 }
684 break;
685 }
686 }
687
688 // If no existing index could be used, add a new one.
689 if (memIndex == 0) {
690 allocatedIndices.emplace_back(allocator);
691 ByteCodeLiveRange &newRange = allocatedIndices.back();
692 newRange.unionWith(defRange);
693
694 // Allocate an index for op/type/value ranges.
695 if (defRange.opRangeIndex) {
696 newRange.opRangeIndex = numOpRanges;
697 valueToRangeIndex[defIt.first] = numOpRanges++;
698 } else if (defRange.typeRangeIndex) {
699 newRange.typeRangeIndex = numTypeRanges;
700 valueToRangeIndex[defIt.first] = numTypeRanges++;
701 } else if (defRange.valueRangeIndex) {
702 newRange.valueRangeIndex = numValueRanges;
703 valueToRangeIndex[defIt.first] = numValueRanges++;
704 }
705
706 memIndex = allocatedIndices.size();
707 ++numIndices;
708 }
709 }
710
711 // Print the index usage and ensure that we did not run out of index space.
712 LDBG() << "Allocated " << allocatedIndices.size() << " indices "
713 << "(down from initial " << valueDefRanges.size() << ").";
714 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
715 "Ran out of memory for allocated indices");
716
717 // Update the max number of indices.
718 if (numIndices > maxValueMemoryIndex)
719 maxValueMemoryIndex = numIndices;
720 if (numOpRanges > maxOpRangeMemoryIndex)
721 maxOpRangeMemoryIndex = numOpRanges;
722 if (numTypeRanges > maxTypeRangeMemoryIndex)
723 maxTypeRangeMemoryIndex = numTypeRanges;
724 if (numValueRanges > maxValueRangeMemoryIndex)
725 maxValueRangeMemoryIndex = numValueRanges;
726}
727
728void Generator::generate(Region *region, ByteCodeWriter &writer) {
729 llvm::ReversePostOrderTraversal<Region *> rpot(region);
730 for (Block *block : rpot) {
731 // Keep track of where this block begins within the matcher function.
732 blockToAddr.try_emplace(block, matcherByteCode.size());
733 for (Operation &op : *block)
734 generate(&op, writer);
735 }
736}
737
738void Generator::generate(Operation *op, ByteCodeWriter &writer) {
739 LDBG() << "Generating bytecode for operation: " << op->getName();
740 LLVM_DEBUG({
741 // The following list must contain all the operations that do not
742 // produce any bytecode.
743 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
744 writer.appendInline(op->getLoc());
745 });
747 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
748 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
749 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
750 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
751 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
752 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
753 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
754 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
755 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
756 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
757 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
758 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
759 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
760 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
761 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
762 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
763 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
764 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
765 pdl_interp::SwitchResultCountOp>(
766 [&](auto interpOp) { this->generate(interpOp, writer); })
767 .DefaultUnreachable("unknown `pdl_interp` operation");
768}
769
770void Generator::generate(pdl_interp::ApplyConstraintOp op,
771 ByteCodeWriter &writer) {
772 // Constraints that should return a value have to be registered as rewrites.
773 // If a constraint and a rewrite of similar name are registered the
774 // constraint takes precedence
775 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
776 writer.appendPDLValueList(op.getArgs());
777 writer.append(ByteCodeField(op.getIsNegated()));
778 ResultRange results = op.getResults();
779 writer.append(ByteCodeField(results.size()));
780 for (Value result : results) {
781 // We record the expected kind of the result, so that we can provide extra
782 // verification of the native rewrite function and handle the failure case
783 // of constraints accordingly.
784 writer.appendPDLValueKind(result);
785
786 // Range results also need to append the range storage index.
787 if (isa<pdl::RangeType>(result.getType()))
788 writer.append(getRangeStorageIndex(result));
789 writer.append(result);
790 }
791 writer.append(op.getSuccessors());
792}
793void Generator::generate(pdl_interp::ApplyRewriteOp op,
794 ByteCodeWriter &writer) {
795 assert(externalRewriterToMemIndex.count(op.getName()) &&
796 "expected index for rewrite function");
797 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
798 writer.appendPDLValueList(op.getArgs());
799
800 ResultRange results = op.getResults();
801 writer.append(ByteCodeField(results.size()));
802 for (Value result : results) {
803 // We record the expected kind of the result, so that we
804 // can provide extra verification of the native rewrite function.
805 writer.appendPDLValueKind(result);
806
807 // Range results also need to append the range storage index.
808 if (isa<pdl::RangeType>(result.getType()))
809 writer.append(getRangeStorageIndex(result));
810 writer.append(result);
811 }
812}
813void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
814 Value lhs = op.getLhs();
815 if (isa<pdl::RangeType>(lhs.getType())) {
816 writer.append(OpCode::AreRangesEqual);
817 writer.appendPDLValueKind(lhs);
818 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
819 return;
820 }
821
822 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
823}
824void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
825 writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
826}
827void Generator::generate(pdl_interp::CheckAttributeOp op,
828 ByteCodeWriter &writer) {
829 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
830 op.getSuccessors());
831}
832void Generator::generate(pdl_interp::CheckOperandCountOp op,
833 ByteCodeWriter &writer) {
834 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
835 static_cast<ByteCodeField>(op.getCompareAtLeast()),
836 op.getSuccessors());
837}
838void Generator::generate(pdl_interp::CheckOperationNameOp op,
839 ByteCodeWriter &writer) {
840 writer.append(OpCode::CheckOperationName, op.getInputOp(),
841 OperationName(op.getName(), ctx), op.getSuccessors());
842}
843void Generator::generate(pdl_interp::CheckResultCountOp op,
844 ByteCodeWriter &writer) {
845 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
846 static_cast<ByteCodeField>(op.getCompareAtLeast()),
847 op.getSuccessors());
848}
849void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
850 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
851 op.getSuccessors());
852}
853void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
854 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
855 op.getSuccessors());
856}
857void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
858 assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
859 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
860}
861void Generator::generate(pdl_interp::CreateAttributeOp op,
862 ByteCodeWriter &writer) {
863 // Simply repoint the memory index of the result to the constant.
864 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
865}
866void Generator::generate(pdl_interp::CreateOperationOp op,
867 ByteCodeWriter &writer) {
868 writer.append(OpCode::CreateOperation, op.getResultOp(),
869 OperationName(op.getName(), ctx));
870 writer.appendPDLValueList(op.getInputOperands());
871
872 // Add the attributes.
873 OperandRange attributes = op.getInputAttributes();
874 writer.append(static_cast<ByteCodeField>(attributes.size()));
875 for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
876 writer.append(std::get<0>(it), std::get<1>(it));
877
878 // Add the result types. If the operation has inferred results, we use a
879 // marker "size" value. Otherwise, we add the list of explicit result types.
880 if (op.getInferredResultTypes())
881 writer.append(kInferTypesMarker);
882 else
883 writer.appendPDLValueList(op.getInputResultTypes());
884}
885void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
886 // Append the correct opcode for the range type.
887 TypeSwitch<Type>(op.getType().getElementType())
888 .Case(
889 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
890 .Case([&](pdl::ValueType) {
891 writer.append(OpCode::CreateDynamicValueRange);
892 });
893
894 writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
895 writer.appendPDLValueList(op->getOperands());
896}
897void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
898 // Simply repoint the memory index of the result to the constant.
899 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
900}
901void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
902 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
903 getRangeStorageIndex(op.getResult()), op.getValue());
904}
905void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
906 writer.append(OpCode::EraseOp, op.getInputOp());
907}
908void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
909 OpCode opCode =
910 TypeSwitch<Type, OpCode>(op.getResult().getType())
911 .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
912 .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
913 .Case([](pdl::TypeType) { return OpCode::ExtractType; })
914 .DefaultUnreachable("unsupported element type");
915 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
916}
917void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
918 writer.append(OpCode::Finalize);
919}
920void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
921 BlockArgument arg = op.getLoopVariable();
922 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
923 writer.appendPDLValueKind(arg.getType());
924 writer.append(curLoopLevel, op.getSuccessor());
925 ++curLoopLevel;
926 if (curLoopLevel > maxLoopLevel)
927 maxLoopLevel = curLoopLevel;
928 generate(&op.getRegion(), writer);
929 --curLoopLevel;
930}
931void Generator::generate(pdl_interp::GetAttributeOp op,
932 ByteCodeWriter &writer) {
933 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
934 op.getNameAttr());
935}
936void Generator::generate(pdl_interp::GetAttributeTypeOp op,
937 ByteCodeWriter &writer) {
938 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
939}
940void Generator::generate(pdl_interp::GetDefiningOpOp op,
941 ByteCodeWriter &writer) {
942 writer.append(OpCode::GetDefiningOp, op.getInputOp());
943 writer.appendPDLValue(op.getValue());
944}
945void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
946 uint32_t index = op.getIndex();
947 if (index < 4)
948 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
949 else
950 writer.append(OpCode::GetOperandN, index);
951 writer.append(op.getInputOp(), op.getValue());
952}
953void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
954 Value result = op.getValue();
955 std::optional<uint32_t> index = op.getIndex();
956 writer.append(OpCode::GetOperands,
957 index.value_or(std::numeric_limits<uint32_t>::max()),
958 op.getInputOp());
959 if (isa<pdl::RangeType>(result.getType()))
960 writer.append(getRangeStorageIndex(result));
961 else
962 writer.append(std::numeric_limits<ByteCodeField>::max());
963 writer.append(result);
964}
965void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
966 uint32_t index = op.getIndex();
967 if (index < 4)
968 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
969 else
970 writer.append(OpCode::GetResultN, index);
971 writer.append(op.getInputOp(), op.getValue());
972}
973void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
974 Value result = op.getValue();
975 std::optional<uint32_t> index = op.getIndex();
976 writer.append(OpCode::GetResults,
977 index.value_or(std::numeric_limits<uint32_t>::max()),
978 op.getInputOp());
979 if (isa<pdl::RangeType>(result.getType()))
980 writer.append(getRangeStorageIndex(result));
981 else
982 writer.append(std::numeric_limits<ByteCodeField>::max());
983 writer.append(result);
984}
985void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
986 Value operations = op.getOperations();
987 ByteCodeField rangeIndex = getRangeStorageIndex(operations);
988 writer.append(OpCode::GetUsers, operations, rangeIndex);
989 writer.appendPDLValue(op.getValue());
990}
991void Generator::generate(pdl_interp::GetValueTypeOp op,
992 ByteCodeWriter &writer) {
993 if (isa<pdl::RangeType>(op.getType())) {
994 Value result = op.getResult();
995 writer.append(OpCode::GetValueRangeTypes, result,
996 getRangeStorageIndex(result), op.getValue());
997 } else {
998 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
999 }
1000}
1001void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1002 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1003}
1004void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1005 ByteCodeField patternIndex = patterns.size();
1006 patterns.emplace_back(PDLByteCodePattern::create(
1007 op, configMap.lookup(op),
1008 rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1009 writer.append(OpCode::RecordMatch, patternIndex,
1010 SuccessorRange(op.getOperation()), op.getMatchedOps());
1011 writer.appendPDLValueList(op.getInputs());
1012}
1013void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1014 writer.append(OpCode::ReplaceOp, op.getInputOp());
1015 writer.appendPDLValueList(op.getReplValues());
1016}
1017void Generator::generate(pdl_interp::SwitchAttributeOp op,
1018 ByteCodeWriter &writer) {
1019 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1020 op.getCaseValuesAttr(), op.getSuccessors());
1021}
1022void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1023 ByteCodeWriter &writer) {
1024 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1025 op.getCaseValuesAttr(), op.getSuccessors());
1026}
1027void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1028 ByteCodeWriter &writer) {
1029 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1030 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1031 });
1032 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1033 op.getSuccessors());
1034}
1035void Generator::generate(pdl_interp::SwitchResultCountOp op,
1036 ByteCodeWriter &writer) {
1037 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1038 op.getCaseValuesAttr(), op.getSuccessors());
1039}
1040void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1041 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1042 op.getSuccessors());
1043}
1044void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1045 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1046 op.getSuccessors());
1047}
1048
1049//===----------------------------------------------------------------------===//
1050// PDLByteCode
1051//===----------------------------------------------------------------------===//
1052
/// Build the bytecode for every pattern in `module`. The Generator is handed
/// the bytecode buffers and size counters of this object (presumably by
/// reference, so that generate() fills them in -- confirm in the Generator
/// constructor). Afterwards the native constraint and rewrite functions are
/// moved out of the input maps into the flat `constraintFunctions` /
/// `rewriteFunctions` tables.
// NOTE(review): this listing jumps from source line 1054 to 1056, and the body
// passes `configMap`, which none of the visible parameters declare; the
// missing line presumably declares it -- confirm against the original source.
1053PDLByteCode::PDLByteCode(
1054 ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1056 llvm::StringMap<PDLConstraintFunction> constraintFns,
1057 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1058 : configs(std::move(configs)) {
1059 Generator generator(module.getContext(), uniquedData, matcherByteCode,
1060 rewriterByteCode, patterns, maxValueMemoryIndex,
1061 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1062 maxLoopLevel, constraintFns, rewriteFns, configMap);
1063 generator.generate(module);
1064
1065 // Initialize the external functions.
// NOTE(review): the executor indexes these tables with indices emitted by the
// Generator, which presumably assigned them while iterating these same maps;
// the iteration order here must stay identical -- confirm.
1066 for (auto &it : constraintFns)
1067 constraintFunctions.push_back(std::move(it.second));
1068 for (auto &it : rewriteFns)
1069 rewriteFunctions.push_back(std::move(it.second));
1070}
1071
1072/// Initialize the given state such that it can be used to execute the current
1073/// bytecode.
// NOTE(review): the function signature (source line 1074) is missing from this
// listing; it presumably declares the `state` parameter used below.
// Size every memory pool to the maximum slot/range/loop counts recorded for
// this bytecode; slots are filled with null pointers, empty ranges, or zero
// loop indices.
1075 state.memory.resize(maxValueMemoryIndex, nullptr);
1076 state.opRangeMemory.resize(maxOpRangeCount);
1077 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1078 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1079 state.loopIndex.resize(maxLoopLevel, 0);
// Seed the state's per-pattern benefits from each pattern's own benefit, one
// entry per pattern in pattern order. The state keeps its own copy, presumably
// so benefits can be adjusted later -- confirm.
1080 state.currentPatternBenefits.reserve(patterns.size());
1081 for (const PDLByteCodePattern &pattern : patterns)
1082 state.currentPatternBenefits.push_back(pattern.getBenefit());
1083}
1084
1085//===----------------------------------------------------------------------===//
1086// ByteCode Execution
1087//===----------------------------------------------------------------------===//
1088
1089namespace {
1090/// This class is an instantiation of the PDLResultList that provides access to
1091/// the returned results. This API is not on `PDLResultList` to avoid
1092/// overexposing access to information specific solely to the ByteCode.
1093class ByteCodeRewriteResultList : public PDLResultList {
1094public:
1095 ByteCodeRewriteResultList(unsigned maxNumResults)
1096 : PDLResultList(maxNumResults) {}
1097
1098 /// Return the list of PDL results.
1099 MutableArrayRef<PDLValue> getResults() { return results; }
1100
1101 /// Return the type ranges allocated by this list.
1102 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1103 return allocatedTypeRanges;
1104 }
1105
1106 /// Return the value ranges allocated by this list.
1107 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1108 return allocatedValueRanges;
1109 }
1110};
1111
1112/// This class provides support for executing a bytecode stream.
/// The executor owns none of its state: array arguments are kept as
/// ArrayRef/MutableArrayRef views and the allocated-range vectors are held by
/// reference, so the caller must keep all of them alive while it is in use.
1113class ByteCodeExecutor {
1114public:
  /// `curCodeIt` is the position to start reading at; `code` is the whole
  /// bytecode buffer, which jump addresses index into (see selectJump).
1115 ByteCodeExecutor(
1116 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1117 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1118 MutableArrayRef<TypeRange> typeRangeMemory,
1119 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1120 MutableArrayRef<ValueRange> valueRangeMemory,
1121 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1122 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1123 ArrayRef<ByteCodeField> code,
1124 ArrayRef<PatternBenefit> currentPatternBenefits,
1125 ArrayRef<PDLByteCodePattern> patterns,
1126 ArrayRef<PDLConstraintFunction> constraintFunctions,
1127 ArrayRef<PDLRewriteFunction> rewriteFunctions)
1128 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1129 typeRangeMemory(typeRangeMemory),
1130 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1131 valueRangeMemory(valueRangeMemory),
1132 allocatedValueRangeMemory(allocatedValueRangeMemory),
1133 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1134 currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1135 constraintFunctions(constraintFunctions),
1136 rewriteFunctions(rewriteFunctions) {}
1137
1138 /// Start executing the code at the current bytecode index. `matches` is an
1139 /// optional field provided when this function is executed in a matching
1140 /// context.
1141 LogicalResult
1142 execute(PatternRewriter &rewriter,
1143 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1144 std::optional<Location> mainRewriteLoc = {});
1145
1146private:
1147 /// Internal implementation of executing each of the bytecode commands.
1148 void executeApplyConstraint(PatternRewriter &rewriter);
1149 LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1150 void executeAreEqual();
1151 void executeAreRangesEqual();
1152 void executeBranch();
1153 void executeCheckOperandCount();
1154 void executeCheckOperationName();
1155 void executeCheckResultCount();
1156 void executeCheckTypes();
1157 void executeContinue();
1158 void executeCreateConstantTypeRange();
1159 void executeCreateOperation(PatternRewriter &rewriter,
1160 Location mainRewriteLoc);
1161 template <typename T>
1162 void executeDynamicCreateRange(StringRef type);
1163 void executeEraseOp(PatternRewriter &rewriter);
1164 template <typename T, typename Range, PDLValue::Kind kind>
1165 void executeExtract();
1166 void executeFinalize();
1167 void executeForEach();
1168 void executeGetAttribute();
1169 void executeGetAttributeType();
1170 void executeGetDefiningOp();
1171 void executeGetOperand(unsigned index);
1172 void executeGetOperands();
1173 void executeGetResult(unsigned index);
1174 void executeGetResults();
1175 void executeGetUsers();
1176 void executeGetValueType();
1177 void executeGetValueRangeTypes();
1178 void executeIsNotNull();
1179 void executeRecordMatch(PatternRewriter &rewriter,
1180 SmallVectorImpl<PDLByteCode::MatchResult> &matches);
1181 void executeReplaceOp(PatternRewriter &rewriter);
1182 void executeSwitchAttribute();
1183 void executeSwitchOperandCount();
1184 void executeSwitchOperationName();
1185 void executeSwitchResultCount();
1186 void executeSwitchType();
1187 void executeSwitchTypes();
  /// Store the results of a native constraint/rewrite call into memory, or, if
  /// the call failed, skip over the result descriptors in the bytecode.
1188 void processNativeFunResults(ByteCodeRewriteResultList &results,
1189 unsigned numResults,
1190 LogicalResult &rewriteResult);
1191
1192 /// Pushes a code iterator to the stack.
1193 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1194
1195 /// Pops a code iterator from the stack and resumes execution at it.
1196 void popCodeIt() {
1197 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1198 curCodeIt = resumeCodeIt.pop_back_val();
1199 }
1200
1201 /// Return the bytecode iterator at the start of the current op code.
  /// NOTE(review): the LLVM_DEBUG branch assumes that, under the same debug
  /// condition, an inline Location pointer was written after each op code --
  /// presumably mirrored on the writer side; confirm.
1202 const ByteCodeField *getPrevCodeIt() const {
1203 LLVM_DEBUG({
1204 // Account for the op code and the Location stored inline.
1205 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1206 });
1207
1208 // Account for the op code only.
1209 return curCodeIt - 1;
1210 }
1211
1212 /// Read a value from the bytecode buffer, optionally skipping a certain
1213 /// number of prefix values. These methods always update the buffer to point
1214 /// to the next field after the read data.
1215 template <typename T = ByteCodeField>
1216 T read(size_t skipN = 0) {
1217 curCodeIt += skipN;
1218 return readImpl<T>();
1219 }
1220 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1221
1222 /// Read a list of values from the bytecode buffer.
  /// The list is encoded as a count followed by that many `ValueT` entries;
  /// `list` is cleared first.
1223 template <typename ValueT, typename T>
1224 void readList(SmallVectorImpl<T> &list) {
1225 list.clear();
1226 for (unsigned i = 0, e = read(); i != e; ++i)
1227 list.push_back(read<ValueT>());
1228 }
1229
1230 /// Read a list of values from the bytecode buffer. The values may be encoded
1231 /// either as a single element or a range of elements.
  /// Each entry is prefixed with its PDLValue::Kind. Unlike the template
  /// overload above, these overloads append to `list` without clearing it.
1232 void readList(SmallVectorImpl<Type> &list) {
1233 for (unsigned i = 0, e = read(); i != e; ++i) {
1234 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1235 list.push_back(read<Type>());
1236 } else {
1237 TypeRange *values = read<TypeRange *>();
1238 list.append(values->begin(), values->end());
1239 }
1240 }
1241 }
1242 void readList(SmallVectorImpl<Value> &list) {
1243 for (unsigned i = 0, e = read(); i != e; ++i) {
1244 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1245 list.push_back(read<Value>());
1246 } else {
1247 ValueRange *values = read<ValueRange *>();
1248 list.append(values->begin(), values->end());
1249 }
1250 }
1251 }
1252
1253 /// Read a value stored inline as a pointer.
  /// The pointer occupies sizeof(const void *) / sizeof(ByteCodeField)
  /// consecutive fields and is copied out with memcpy.
1254 template <typename T>
1255 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1256 readInline() {
1257 const void *pointer;
1258 std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1259 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1260 return T::getFromOpaquePointer(pointer);
1261 }
1262
  /// Advance past `skipN` bytecode fields without interpreting them.
1263 void skip(size_t skipN) { curCodeIt += skipN; }
1264
1265 /// Jump to a specific successor based on a predicate value.
  /// `true` selects successor 0 and `false` selects successor 1.
1266 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1267 /// Jump to a specific successor based on a destination index.
  /// Successor addresses are stored back to back, each a ByteCodeAddr spanning
  /// two fields (see the static_assert in readImpl), hence `destIndex * 2`.
1268 void selectJump(size_t destIndex) {
1269 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1270 }
1271
1272 /// Handle a switch operation with the provided value and cases.
  /// Successor 0 is the default destination; a match on case `i` jumps to
  /// successor `i + 1`.
1273 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1274 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1275 LDBG() << "Switch operation:\n * Value: " << value
1276 << "\n * Cases: " << llvm::interleaved(cases);
1277
1278 // Check to see if the value is within the case list. Jump to
1279 // the correct successor index based on the result.
1280 for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1281 if (cmp(*it, value))
1282 return selectJump(size_t((it - cases.begin()) + 1));
1283 selectJump(size_t(0));
1284 }
1285
1286 /// Store a pointer to memory.
1287 void storeToMemory(unsigned index, const void *value) {
1288 memory[index] = value;
1289 }
1290
1291 /// Store a value to memory as an opaque pointer.
1292 template <typename T>
1293 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1294 storeToMemory(unsigned index, T value) {
1295 memory[index] = value.getAsOpaquePointer();
1296 }
1297
1298 /// Internal implementation of reading various data types from the bytecode
1299 /// stream.
  /// A memory index below `memory.size()` addresses the mutable execution
  /// memory; larger indices address `uniquedMemory` (offset by memory.size()).
1300 template <typename T>
1301 const void *readFromMemory() {
1302 size_t index = *curCodeIt++;
1303
1304 // If this type is an SSA value, it can only be stored in non-const memory.
1305 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1306 Value>::value ||
1307 index < memory.size())
1308 return memory[index];
1309
1310 // Otherwise, if this index is not inbounds it is uniqued.
1311 return uniquedMemory[index - memory.size()];
1312 }
1313 template <typename T>
1314 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1315 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1316 }
1317 template <typename T>
1318 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1319 T>
1320 readImpl() {
1321 return T(T::getFromOpaquePointer(readFromMemory<T>()));
1322 }
  // A PDLValue is encoded as its kind followed by the kind-specific payload.
1323 template <typename T>
1324 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1325 switch (read<PDLValue::Kind>()) {
1326 case PDLValue::Kind::Attribute:
1327 return read<Attribute>();
1328 case PDLValue::Kind::Operation:
1329 return read<Operation *>();
1330 case PDLValue::Kind::Type:
1331 return read<Type>();
1332 case PDLValue::Kind::Value:
1333 return read<Value>();
1334 case PDLValue::Kind::TypeRange:
1335 return read<TypeRange *>();
1336 case PDLValue::Kind::ValueRange:
1337 return read<ValueRange *>();
1338 }
1339 llvm_unreachable("unhandled PDLValue::Kind");
1340 }
  // An address is stored inline across two fields and copied out with memcpy.
1341 template <typename T>
1342 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1343 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1344 "unexpected ByteCode address size");
1345 ByteCodeAddr result;
1346 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1347 curCodeIt += 2;
1348 return result;
1349 }
1350 template <typename T>
1351 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1352 return *curCodeIt++;
1353 }
1354 template <typename T>
1355 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1356 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1357 }
1358
1359 /// Assign the given range to the given memory index. This allocates a new
1360 /// range object if necessary.
  /// Only ranges whose element type is Type or Value are supported; the
  /// allocated storage is kept alive in the matching allocated*RangeMemory.
1361 template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1362 void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1363 unsigned rangeIndex) {
1364 // Generic lambda shared by both supported range kinds below.
1365 auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1366 // If the input range is empty, we don't need to allocate anything.
1367 if (range.empty()) {
1368 rangeMemory[rangeIndex] = {};
1369 } else {
1370 // Allocate a buffer for this type range.
1371 llvm::OwningArrayRef<T> storage(llvm::size(range));
1372 llvm::copy(range, storage.begin());
1373
1374 // Assign this to the range slot and use the range as the value for the
1375 // memory index.
1376 allocatedRangeMemory.emplace_back(std::move(storage));
1377 rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1378 }
1379 memory[memIndex] = &rangeMemory[rangeIndex];
1380 };
1381
1382 // Dispatch based on the concrete range type.
1383 if constexpr (std::is_same_v<T, Type>) {
1384 return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1385 } else if constexpr (std::is_same_v<T, Value>) {
1386 return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1387 } else {
1388 llvm_unreachable("unhandled range type");
1389 }
1390 }
1391
1392 /// The current read position within the bytecode buffer (`code`).
1393 const ByteCodeField *curCodeIt;
1394
1395 /// The stack of bytecode positions at which to resume operation.
1396 SmallVector<const ByteCodeField *> resumeCodeIt;
1397
1398 /// The current execution memory.
1399 MutableArrayRef<const void *> memory;
1400 MutableArrayRef<OwningOpRange> opRangeMemory;
1401 MutableArrayRef<TypeRange> typeRangeMemory;
1402 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1403 MutableArrayRef<ValueRange> valueRangeMemory;
1404 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1405
1406 /// The current loop indices.
1407 MutableArrayRef<unsigned> loopIndex;
1408
1409 /// References to ByteCode data necessary for execution.
1410 ArrayRef<const void *> uniquedMemory;
1411 ArrayRef<ByteCodeField> code;
1412 ArrayRef<PatternBenefit> currentPatternBenefits;
1413 ArrayRef<PDLByteCodePattern> patterns;
1414 ArrayRef<PDLConstraintFunction> constraintFunctions;
1415 ArrayRef<PDLRewriteFunction> rewriteFunctions;
1416};
1417} // namespace
1418
1419void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1420 LDBG() << "Executing ApplyConstraint:";
1421 ByteCodeField fun_idx = read();
1422 SmallVector<PDLValue, 16> args;
1423 readList<PDLValue>(args);
1424
1425 LDBG() << " * Arguments: " << llvm::interleaved(args);
1426
1427 ByteCodeField isNegated = read();
1428 LDBG() << " * isNegated: " << isNegated;
1429
1430 ByteCodeField numResults = read();
1431 const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1432 ByteCodeRewriteResultList results(numResults);
1433 LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1434 [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1435 if (succeeded(rewriteResult)) {
1436 LDBG() << " * Constraint succeeded, results: "
1437 << llvm::interleaved(constraintResults);
1438 } else {
1439 LDBG() << " * Constraint failed";
1440 }
1441 assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1442 "native PDL rewrite function succeeded but returned "
1443 "unexpected number of results");
1444 processNativeFunResults(results, numResults, rewriteResult);
1445
1446 // Depending on the constraint jump to the proper destination.
1447 selectJump(isNegated != succeeded(rewriteResult));
1448}
1449
1450LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1451 LDBG() << "Executing ApplyRewrite:";
1452 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1453 SmallVector<PDLValue, 16> args;
1454 readList<PDLValue>(args);
1455
1456 LDBG() << " * Arguments: " << llvm::interleaved(args);
1457
1458 // Execute the rewrite function.
1459 ByteCodeField numResults = read();
1460 ByteCodeRewriteResultList results(numResults);
1461 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1462
1463 assert(results.getResults().size() == numResults &&
1464 "native PDL rewrite function returned unexpected number of results");
1465
1466 processNativeFunResults(results, numResults, rewriteResult);
1467
1468 if (failed(rewriteResult)) {
1469 LDBG() << " - Failed";
1470 return failure();
1471 }
1472 return success();
1473}
1474
1475void ByteCodeExecutor::processNativeFunResults(
1476 ByteCodeRewriteResultList &results, unsigned numResults,
1477 LogicalResult &rewriteResult) {
1478 if (failed(rewriteResult)) {
1479 // Skip the according number of values on the buffer on failure and exit
1480 // early as there are no results to process.
1481 for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1482 const PDLValue::Kind resultKind = read<PDLValue::Kind>();
1483 if (resultKind == PDLValue::Kind::TypeRange ||
1484 resultKind == PDLValue::Kind::ValueRange) {
1485 skip(2);
1486 } else {
1487 skip(1);
1488 }
1489 }
1490 return;
1491 }
1492
1493 // Store the results in the bytecode memory
1494 for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1495 PDLValue::Kind resultKind = read<PDLValue::Kind>();
1496 (void)resultKind;
1497 PDLValue result = results.getResults()[resultIdx];
1498 LDBG() << " * Result: " << result;
1499 assert(result.getKind() == resultKind &&
1500 "native PDL rewrite function returned an unexpected type of "
1501 "result");
1502 // If the result is a range, we need to copy it over to the bytecodes
1503 // range memory.
1504 if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1505 unsigned rangeIndex = read();
1506 typeRangeMemory[rangeIndex] = *typeRange;
1507 memory[read()] = &typeRangeMemory[rangeIndex];
1508 } else if (std::optional<ValueRange> valueRange =
1509 result.dyn_cast<ValueRange>()) {
1510 unsigned rangeIndex = read();
1511 valueRangeMemory[rangeIndex] = *valueRange;
1512 memory[read()] = &valueRangeMemory[rangeIndex];
1513 } else {
1514 memory[read()] = result.getAsOpaquePointer();
1515 }
1516 }
1517
1518 // Copy over any underlying storage allocated for result ranges.
1519 for (auto &it : results.getAllocatedTypeRanges())
1520 allocatedTypeRangeMemory.push_back(std::move(it));
1521 for (auto &it : results.getAllocatedValueRanges())
1522 allocatedValueRangeMemory.push_back(std::move(it));
1523}
1524
1525void ByteCodeExecutor::executeAreEqual() {
1526 LDBG() << "Executing AreEqual:";
1527 const void *lhs = read<const void *>();
1528 const void *rhs = read<const void *>();
1529
1530 LDBG() << " * " << lhs << " == " << rhs;
1531 selectJump(lhs == rhs);
1532}
1533
1534void ByteCodeExecutor::executeAreRangesEqual() {
1535 LDBG() << "Executing AreRangesEqual:";
1536 PDLValue::Kind valueKind = read<PDLValue::Kind>();
1537 const void *lhs = read<const void *>();
1538 const void *rhs = read<const void *>();
1539
1540 switch (valueKind) {
1541 case PDLValue::Kind::TypeRange: {
1542 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1543 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1544 LDBG() << " * " << lhs << " == " << rhs;
1545 selectJump(*lhsRange == *rhsRange);
1546 break;
1547 }
1548 case PDLValue::Kind::ValueRange: {
1549 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1550 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1551 LDBG() << " * " << lhs << " == " << rhs;
1552 selectJump(*lhsRange == *rhsRange);
1553 break;
1554 }
1555 default:
1556 llvm_unreachable("unexpected `AreRangesEqual` value kind");
1557 }
1558}
1559
1560void ByteCodeExecutor::executeBranch() {
1561 LDBG() << "Executing Branch";
1562 curCodeIt = &code[read<ByteCodeAddr>()];
1563}
1564
1565void ByteCodeExecutor::executeCheckOperandCount() {
1566 LDBG() << "Executing CheckOperandCount:";
1567 Operation *op = read<Operation *>();
1568 uint32_t expectedCount = read<uint32_t>();
1569 bool compareAtLeast = read();
1570
1571 LDBG() << " * Found: " << op->getNumOperands()
1572 << "\n * Expected: " << expectedCount
1573 << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
1574 if (compareAtLeast)
1575 selectJump(op->getNumOperands() >= expectedCount);
1576 else
1577 selectJump(op->getNumOperands() == expectedCount);
1578}
1579
1580void ByteCodeExecutor::executeCheckOperationName() {
1581 LDBG() << "Executing CheckOperationName:";
1582 Operation *op = read<Operation *>();
1583 OperationName expectedName = read<OperationName>();
1584
1585 LDBG() << " * Found: \"" << op->getName() << "\"\n * Expected: \""
1586 << expectedName << "\"";
1587 selectJump(op->getName() == expectedName);
1588}
1589
1590void ByteCodeExecutor::executeCheckResultCount() {
1591 LDBG() << "Executing CheckResultCount:";
1592 Operation *op = read<Operation *>();
1593 uint32_t expectedCount = read<uint32_t>();
1594 bool compareAtLeast = read();
1595
1596 LDBG() << " * Found: " << op->getNumResults()
1597 << "\n * Expected: " << expectedCount
1598 << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
1599 if (compareAtLeast)
1600 selectJump(op->getNumResults() >= expectedCount);
1601 else
1602 selectJump(op->getNumResults() == expectedCount);
1603}
1604
1605void ByteCodeExecutor::executeCheckTypes() {
1606 LDBG() << "Executing AreEqual:";
1607 TypeRange *lhs = read<TypeRange *>();
1608 Attribute rhs = read<Attribute>();
1609 LDBG() << " * " << lhs << " == " << rhs;
1610
1611 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1612}
1613
1614void ByteCodeExecutor::executeContinue() {
1615 ByteCodeField level = read();
1616 LDBG() << "Executing Continue\n * Level: " << level;
1617 ++loopIndex[level];
1618 popCodeIt();
1619}
1620
1621void ByteCodeExecutor::executeCreateConstantTypeRange() {
1622 LDBG() << "Executing CreateConstantTypeRange:";
1623 unsigned memIndex = read();
1624 unsigned rangeIndex = read();
1625 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1626
1627 LDBG() << " * Types: " << typesAttr;
1628 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1629 rangeIndex);
1630}
1631
1632void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1633 Location mainRewriteLoc) {
1634 LDBG() << "Executing CreateOperation:";
1635
1636 unsigned memIndex = read();
1637 OperationState state(mainRewriteLoc, read<OperationName>());
1638 readList(state.operands);
1639 for (unsigned i = 0, e = read(); i != e; ++i) {
1640 StringAttr name = read<StringAttr>();
1641 if (Attribute attr = read<Attribute>())
1642 state.addAttribute(name, attr);
1643 }
1644
1645 // Read in the result types. If the "size" is the sentinel value, this
1646 // indicates that the result types should be inferred.
1647 unsigned numResults = read();
1648 if (numResults == kInferTypesMarker) {
1649 InferTypeOpInterface::Concept *inferInterface =
1650 state.name.getInterface<InferTypeOpInterface>();
1651 assert(inferInterface &&
1652 "expected operation to provide InferTypeOpInterface");
1653
1654 // TODO: Handle failure.
1655 if (failed(inferInterface->inferReturnTypes(
1656 state.getContext(), state.location, state.operands,
1657 state.attributes.getDictionary(state.getContext()),
1658 state.getRawProperties(), state.regions, state.types)))
1659 return;
1660 } else {
1661 // Otherwise, this is a fixed number of results.
1662 for (unsigned i = 0; i != numResults; ++i) {
1663 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1664 state.types.push_back(read<Type>());
1665 } else {
1666 TypeRange *resultTypes = read<TypeRange *>();
1667 state.types.append(resultTypes->begin(), resultTypes->end());
1668 }
1669 }
1670 }
1671
1672 Operation *resultOp = rewriter.create(state);
1673 memory[memIndex] = resultOp;
1674
1675 LDBG() << " * Attributes: "
1676 << state.attributes.getDictionary(state.getContext())
1677 << "\n * Operands: " << llvm::interleaved(state.operands)
1678 << "\n * Result Types: " << llvm::interleaved(state.types)
1679 << "\n * Result: " << *resultOp;
1680}
1681
1682template <typename T>
1683void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1684 LDBG() << "Executing CreateDynamic" << type << "Range:";
1685 unsigned memIndex = read();
1686 unsigned rangeIndex = read();
1687 SmallVector<T> values;
1688 readList(values);
1689
1690 LDBG() << " * " << type << "s: " << llvm::interleaved(values);
1691
1692 assignRangeToMemory(values, memIndex, rangeIndex);
1693}
1694
1695void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1696 LDBG() << "Executing EraseOp:";
1697 Operation *op = read<Operation *>();
1698
1699 LDBG() << " * Operation: " << *op;
1700 rewriter.eraseOp(op);
1701}
1702
1703template <typename T, typename Range, PDLValue::Kind kind>
1704void ByteCodeExecutor::executeExtract() {
1705 LDBG() << "Executing Extract" << kind << ":";
1706 Range *range = read<Range *>();
1707 unsigned index = read<uint32_t>();
1708 unsigned memIndex = read();
1709
1710 if (!range) {
1711 memory[memIndex] = nullptr;
1712 return;
1713 }
1714
1715 T result = index < range->size() ? (*range)[index] : T();
1716 LDBG() << " * " << kind << "s(" << range->size() << ")";
1717 LDBG() << " * Index: " << index;
1718 LDBG() << " * Result: " << result;
1719 storeToMemory(memIndex, result);
1720}
1721
1722void ByteCodeExecutor::executeFinalize() { LDBG() << "Executing Finalize"; }
1723
1724void ByteCodeExecutor::executeForEach() {
1725 LDBG() << "Executing ForEach:";
1726 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1727 unsigned rangeIndex = read();
1728 unsigned memIndex = read();
1729 const void *value = nullptr;
1730
1731 switch (read<PDLValue::Kind>()) {
1732 case PDLValue::Kind::Operation: {
1733 unsigned &index = loopIndex[read()];
1734 ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1735 assert(index <= array.size() && "iterated past the end");
1736 if (index < array.size()) {
1737 LDBG() << " * Result: " << array[index];
1738 value = array[index];
1739 break;
1740 }
1741
1742 LDBG() << " * Done";
1743 index = 0;
1744 selectJump(size_t(0));
1745 return;
1746 }
1747 default:
1748 llvm_unreachable("unexpected `ForEach` value kind");
1749 }
1750
1751 // Store the iterate value and the stack address.
1752 memory[memIndex] = value;
1753 pushCodeIt(prevCodeIt);
1754
1755 // Skip over the successor (we will enter the body of the loop).
1756 read<ByteCodeAddr>();
1757}
1758
1759void ByteCodeExecutor::executeGetAttribute() {
1760 LDBG() << "Executing GetAttribute:";
1761 unsigned memIndex = read();
1762 Operation *op = read<Operation *>();
1763 StringAttr attrName = read<StringAttr>();
1764 Attribute attr = op->getAttr(attrName);
1765
1766 LDBG() << " * Operation: " << *op << "\n * Attribute: " << attrName
1767 << "\n * Result: " << attr;
1768 memory[memIndex] = attr.getAsOpaquePointer();
1769}
1770
1771void ByteCodeExecutor::executeGetAttributeType() {
1772 LDBG() << "Executing GetAttributeType:";
1773 unsigned memIndex = read();
1774 Attribute attr = read<Attribute>();
1775 Type type;
1776 if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1777 type = typedAttr.getType();
1778
1779 LDBG() << " * Attribute: " << attr << "\n * Result: " << type;
1780 memory[memIndex] = type.getAsOpaquePointer();
1781}
1782
1783void ByteCodeExecutor::executeGetDefiningOp() {
1784 LDBG() << "Executing GetDefiningOp:";
1785 unsigned memIndex = read();
1786 Operation *op = nullptr;
1787 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1788 Value value = read<Value>();
1789 if (value)
1790 op = value.getDefiningOp();
1791 LDBG() << " * Value: " << value;
1792 } else {
1793 ValueRange *values = read<ValueRange *>();
1794 if (values && !values->empty()) {
1795 op = values->front().getDefiningOp();
1796 }
1797 LDBG() << " * Values: " << values;
1798 }
1799
1800 LDBG() << " * Result: " << op;
1801 memory[memIndex] = op;
1802}
1803
1804void ByteCodeExecutor::executeGetOperand(unsigned index) {
1805 Operation *op = read<Operation *>();
1806 unsigned memIndex = read();
1807 Value operand =
1808 index < op->getNumOperands() ? op->getOperand(index) : Value();
1809
1810 LDBG() << " * Operation: " << *op << "\n * Index: " << index
1811 << "\n * Result: " << operand;
1812 memory[memIndex] = operand.getAsOpaquePointer();
1813}
1814
1815/// This function is the internal implementation of `GetResults` and
1816/// `GetOperands` that provides support for extracting a value range from the
1817/// given operation.
1818template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1819static void *
1820executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1821 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1822 MutableArrayRef<ValueRange> valueRangeMemory) {
1823 // Check for the sentinel index that signals that all values should be
1824 // returned.
1825 if (index == std::numeric_limits<uint32_t>::max()) {
1826 LDBG() << " * Getting all values";
1827 // `values` is already the full value range.
1828
1829 // Otherwise, check to see if this operation uses AttrSizedSegments.
1830 } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1831 LDBG() << " * Extracting values from `" << attrSizedSegments << "`";
1832
1833 auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1834 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1835 return nullptr;
1836
1837 ArrayRef<int32_t> segments = segmentAttr;
1838 unsigned startIndex = llvm::sum_of(segments.take_front(index));
1839 values = values.slice(startIndex, *std::next(segments.begin(), index));
1840
1841 LDBG() << " * Extracting range[" << startIndex << ", "
1842 << *std::next(segments.begin(), index) << "]";
1843
1844 // Otherwise, assume this is the last operand group of the operation.
1845 // FIXME: We currently don't support operations with
1846 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1847 // have a way to detect it's presence.
1848 } else if (values.size() >= index) {
1849 LDBG() << " * Treating values as trailing variadic range";
1850 values = values.drop_front(index);
1851
1852 // If we couldn't detect a way to compute the values, bail out.
1853 } else {
1854 return nullptr;
1855 }
1856
1857 // If the range index is valid, we are returning a range.
1858 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1859 valueRangeMemory[rangeIndex] = values;
1860 return &valueRangeMemory[rangeIndex];
1861 }
1862
1863 // If a range index wasn't provided, the range is required to be non-variadic.
1864 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1865}
1866
1867void ByteCodeExecutor::executeGetOperands() {
1868 LDBG() << "Executing GetOperands:";
1869 unsigned index = read<uint32_t>();
1870 Operation *op = read<Operation *>();
1871 ByteCodeField rangeIndex = read();
1872
1874 op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
1875 valueRangeMemory);
1876 if (!result)
1877 LDBG() << " * Invalid operand range";
1878 memory[read()] = result;
1879}
1880
1881void ByteCodeExecutor::executeGetResult(unsigned index) {
1882 Operation *op = read<Operation *>();
1883 unsigned memIndex = read();
1884 OpResult result =
1885 index < op->getNumResults() ? op->getResult(index) : OpResult();
1886
1887 LDBG() << " * Operation: " << *op << "\n * Index: " << index
1888 << "\n * Result: " << result;
1889 memory[memIndex] = result.getAsOpaquePointer();
1890}
1891
1892void ByteCodeExecutor::executeGetResults() {
1893 LDBG() << "Executing GetResults:";
1894 unsigned index = read<uint32_t>();
1895 Operation *op = read<Operation *>();
1896 ByteCodeField rangeIndex = read();
1897
1899 op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
1900 valueRangeMemory);
1901 if (!result)
1902 LDBG() << " * Invalid result range";
1903 memory[read()] = result;
1904}
1905
1906void ByteCodeExecutor::executeGetUsers() {
1907 LDBG() << "Executing GetUsers:";
1908 unsigned memIndex = read();
1909 unsigned rangeIndex = read();
1910 OwningOpRange &range = opRangeMemory[rangeIndex];
1911 memory[memIndex] = &range;
1912
1913 range = OwningOpRange();
1914 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1915 // Read the value.
1916 Value value = read<Value>();
1917 if (!value)
1918 return;
1919 LDBG() << " * Value: " << value;
1920
1921 // Extract the users of a single value.
1922 range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1923 llvm::copy(value.getUsers(), range.begin());
1924 } else {
1925 // Read a range of values.
1926 ValueRange *values = read<ValueRange *>();
1927 if (!values)
1928 return;
1929 LDBG() << " * Values (" << values->size()
1930 << "): " << llvm::interleaved(*values);
1931
1932 // Extract all the users of a range of values.
1933 SmallVector<Operation *> users;
1934 for (Value value : *values)
1935 users.append(value.user_begin(), value.user_end());
1936 range = OwningOpRange(users.size());
1937 llvm::copy(users, range.begin());
1938 }
1939
1940 LDBG() << " * Result: " << range.size() << " operations";
1941}
1942
1943void ByteCodeExecutor::executeGetValueType() {
1944 LDBG() << "Executing GetValueType:";
1945 unsigned memIndex = read();
1946 Value value = read<Value>();
1947 Type type = value ? value.getType() : Type();
1948
1949 LDBG() << " * Value: " << value << "\n * Result: " << type;
1950 memory[memIndex] = type.getAsOpaquePointer();
1951}
1952
1953void ByteCodeExecutor::executeGetValueRangeTypes() {
1954 LDBG() << "Executing GetValueRangeTypes:";
1955 unsigned memIndex = read();
1956 unsigned rangeIndex = read();
1957 ValueRange *values = read<ValueRange *>();
1958 if (!values) {
1959 LDBG() << " * Values: <NULL>";
1960 memory[memIndex] = nullptr;
1961 return;
1962 }
1963
1964 LDBG() << " * Values (" << values->size()
1965 << "): " << llvm::interleaved(*values)
1966 << "\n * Result: " << llvm::interleaved(values->getType());
1967 typeRangeMemory[rangeIndex] = values->getType();
1968 memory[memIndex] = &typeRangeMemory[rangeIndex];
1969}
1970
1971void ByteCodeExecutor::executeIsNotNull() {
1972 LDBG() << "Executing IsNotNull:";
1973 const void *value = read<const void *>();
1974
1975 LDBG() << " * Value: " << value;
1976 selectJump(value != nullptr);
1977}
1978
1979void ByteCodeExecutor::executeRecordMatch(
1980 PatternRewriter &rewriter,
1981 SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1982 LDBG() << "Executing RecordMatch:";
1983 unsigned patternIndex = read();
1984 PatternBenefit benefit = currentPatternBenefits[patternIndex];
1985 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1986
1987 // If the benefit of the pattern is impossible, skip the processing of the
1988 // rest of the pattern.
1989 if (benefit.isImpossibleToMatch()) {
1990 LDBG() << " * Benefit: Impossible To Match";
1991 curCodeIt = dest;
1992 return;
1993 }
1994
1995 // Create a fused location containing the locations of each of the
1996 // operations used in the match. This will be used as the location for
1997 // created operations during the rewrite that don't already have an
1998 // explicit location set.
1999 unsigned numMatchLocs = read();
2000 SmallVector<Location, 4> matchLocs;
2001 matchLocs.reserve(numMatchLocs);
2002 for (unsigned i = 0; i != numMatchLocs; ++i)
2003 matchLocs.push_back(read<Operation *>()->getLoc());
2004 Location matchLoc = rewriter.getFusedLoc(matchLocs);
2005
2006 LDBG() << " * Benefit: " << benefit.getBenefit();
2007 LDBG() << " * Location: " << matchLoc;
2008 matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2009 PDLByteCode::MatchResult &match = matches.back();
2010
2011 // Record all of the inputs to the match. If any of the inputs are ranges, we
2012 // will also need to remap the range pointer to memory stored in the match
2013 // state.
2014 unsigned numInputs = read();
2015 match.values.reserve(numInputs);
2016 match.typeRangeValues.reserve(numInputs);
2017 match.valueRangeValues.reserve(numInputs);
2018 for (unsigned i = 0; i < numInputs; ++i) {
2019 switch (read<PDLValue::Kind>()) {
2020 case PDLValue::Kind::TypeRange:
2021 match.typeRangeValues.push_back(*read<TypeRange *>());
2022 match.values.push_back(&match.typeRangeValues.back());
2023 break;
2024 case PDLValue::Kind::ValueRange:
2025 match.valueRangeValues.push_back(*read<ValueRange *>());
2026 match.values.push_back(&match.valueRangeValues.back());
2027 break;
2028 default:
2029 match.values.push_back(read<const void *>());
2030 break;
2031 }
2032 }
2033 curCodeIt = dest;
2034}
2035
2036void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2037 LDBG() << "Executing ReplaceOp:";
2038 Operation *op = read<Operation *>();
2039 SmallVector<Value, 16> args;
2040 readList(args);
2041
2042 LDBG() << " * Operation: " << *op
2043 << "\n * Values: " << llvm::interleaved(args);
2044 rewriter.replaceOp(op, args);
2045}
2046
2047void ByteCodeExecutor::executeSwitchAttribute() {
2048 LDBG() << "Executing SwitchAttribute:";
2049 Attribute value = read<Attribute>();
2050 ArrayAttr cases = read<ArrayAttr>();
2051 handleSwitch(value, cases);
2052}
2053
2054void ByteCodeExecutor::executeSwitchOperandCount() {
2055 LDBG() << "Executing SwitchOperandCount:";
2056 Operation *op = read<Operation *>();
2057 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2058
2059 LDBG() << " * Operation: " << *op;
2060 handleSwitch(op->getNumOperands(), cases);
2061}
2062
2063void ByteCodeExecutor::executeSwitchOperationName() {
2064 LDBG() << "Executing SwitchOperationName:";
2065 OperationName value = read<Operation *>()->getName();
2066 size_t caseCount = read();
2067
2068 // The operation names are stored in-line, so to print them out for
2069 // debugging purposes we need to read the array before executing the
2070 // switch so that we can display all of the possible values.
2071 LLVM_DEBUG({
2072 const ByteCodeField *prevCodeIt = curCodeIt;
2073 LDBG() << " * Value: " << value << "\n * Cases: "
2074 << llvm::interleaved(
2075 llvm::map_range(llvm::seq<size_t>(0, caseCount), [&](size_t) {
2076 return read<OperationName>();
2077 }));
2078 curCodeIt = prevCodeIt;
2079 });
2080
2081 // Try to find the switch value within any of the cases.
2082 for (size_t i = 0; i != caseCount; ++i) {
2083 if (read<OperationName>() == value) {
2084 curCodeIt += (caseCount - i - 1);
2085 return selectJump(i + 1);
2086 }
2087 }
2088 selectJump(size_t(0));
2089}
2090
2091void ByteCodeExecutor::executeSwitchResultCount() {
2092 LDBG() << "Executing SwitchResultCount:";
2093 Operation *op = read<Operation *>();
2094 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2095
2096 LDBG() << " * Operation: " << *op;
2097 handleSwitch(op->getNumResults(), cases);
2098}
2099
2100void ByteCodeExecutor::executeSwitchType() {
2101 LDBG() << "Executing SwitchType:";
2102 Type value = read<Type>();
2103 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2104 handleSwitch(value, cases);
2105}
2106
2107void ByteCodeExecutor::executeSwitchTypes() {
2108 LDBG() << "Executing SwitchTypes:";
2109 TypeRange *value = read<TypeRange *>();
2110 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2111 if (!value) {
2112 LDBG() << "Types: <NULL>";
2113 return selectJump(size_t(0));
2114 }
2115 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2116 return value == caseValue.getAsValueRange<TypeAttr>();
2117 });
2118}
2119
2120LogicalResult
2121ByteCodeExecutor::execute(PatternRewriter &rewriter,
2122 SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2123 std::optional<Location> mainRewriteLoc) {
2124 while (true) {
2125 // Print the location of the operation being executed.
2126 LDBG() << readInline<Location>();
2127
2128 OpCode opCode = static_cast<OpCode>(read());
2129 switch (opCode) {
2130 case ApplyConstraint:
2131 executeApplyConstraint(rewriter);
2132 break;
2133 case ApplyRewrite:
2134 if (failed(executeApplyRewrite(rewriter)))
2135 return failure();
2136 break;
2137 case AreEqual:
2138 executeAreEqual();
2139 break;
2140 case AreRangesEqual:
2141 executeAreRangesEqual();
2142 break;
2143 case Branch:
2144 executeBranch();
2145 break;
2146 case CheckOperandCount:
2147 executeCheckOperandCount();
2148 break;
2149 case CheckOperationName:
2150 executeCheckOperationName();
2151 break;
2152 case CheckResultCount:
2153 executeCheckResultCount();
2154 break;
2155 case CheckTypes:
2156 executeCheckTypes();
2157 break;
2158 case Continue:
2159 executeContinue();
2160 break;
2161 case CreateConstantTypeRange:
2162 executeCreateConstantTypeRange();
2163 break;
2164 case CreateOperation:
2165 executeCreateOperation(rewriter, *mainRewriteLoc);
2166 break;
2167 case CreateDynamicTypeRange:
2168 executeDynamicCreateRange<Type>("Type");
2169 break;
2170 case CreateDynamicValueRange:
2171 executeDynamicCreateRange<Value>("Value");
2172 break;
2173 case EraseOp:
2174 executeEraseOp(rewriter);
2175 break;
2176 case ExtractOp:
2177 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2178 break;
2179 case ExtractType:
2180 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2181 break;
2182 case ExtractValue:
2183 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2184 break;
2185 case Finalize:
2186 executeFinalize();
2187 LDBG() << "";
2188 return success();
2189 case ForEach:
2190 executeForEach();
2191 break;
2192 case GetAttribute:
2193 executeGetAttribute();
2194 break;
2195 case GetAttributeType:
2196 executeGetAttributeType();
2197 break;
2198 case GetDefiningOp:
2199 executeGetDefiningOp();
2200 break;
2201 case GetOperand0:
2202 case GetOperand1:
2203 case GetOperand2:
2204 case GetOperand3: {
2205 unsigned index = opCode - GetOperand0;
2206 LDBG() << "Executing GetOperand" << index << ":";
2207 executeGetOperand(index);
2208 break;
2209 }
2210 case GetOperandN:
2211 LDBG() << "Executing GetOperandN:";
2212 executeGetOperand(read<uint32_t>());
2213 break;
2214 case GetOperands:
2215 executeGetOperands();
2216 break;
2217 case GetResult0:
2218 case GetResult1:
2219 case GetResult2:
2220 case GetResult3: {
2221 unsigned index = opCode - GetResult0;
2222 LDBG() << "Executing GetResult" << index << ":";
2223 executeGetResult(index);
2224 break;
2225 }
2226 case GetResultN:
2227 LDBG() << "Executing GetResultN:";
2228 executeGetResult(read<uint32_t>());
2229 break;
2230 case GetResults:
2231 executeGetResults();
2232 break;
2233 case GetUsers:
2234 executeGetUsers();
2235 break;
2236 case GetValueType:
2237 executeGetValueType();
2238 break;
2239 case GetValueRangeTypes:
2240 executeGetValueRangeTypes();
2241 break;
2242 case IsNotNull:
2243 executeIsNotNull();
2244 break;
2245 case RecordMatch:
2246 assert(matches &&
2247 "expected matches to be provided when executing the matcher");
2248 executeRecordMatch(rewriter, *matches);
2249 break;
2250 case ReplaceOp:
2251 executeReplaceOp(rewriter);
2252 break;
2253 case SwitchAttribute:
2254 executeSwitchAttribute();
2255 break;
2256 case SwitchOperandCount:
2257 executeSwitchOperandCount();
2258 break;
2259 case SwitchOperationName:
2260 executeSwitchOperationName();
2261 break;
2262 case SwitchResultCount:
2263 executeSwitchResultCount();
2264 break;
2265 case SwitchType:
2266 executeSwitchType();
2267 break;
2268 case SwitchTypes:
2269 executeSwitchTypes();
2270 break;
2271 }
2272 LDBG() << "";
2273 }
2274}
2275
2276void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2277 SmallVectorImpl<MatchResult> &matches,
2278 PDLByteCodeMutableState &state) const {
2279 // The first memory slot is always the root operation.
2280 state.memory[0] = op;
2281
2282 // The matcher function always starts at code address 0.
2283 ByteCodeExecutor executor(
2284 matcherByteCode.data(), state.memory, state.opRangeMemory,
2285 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2286 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2287 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2288 constraintFunctions, rewriteFunctions);
2289 LogicalResult executeResult = executor.execute(rewriter, &matches);
2290 (void)executeResult;
2291 assert(succeeded(executeResult) && "unexpected matcher execution failure");
2292
2293 // Order the found matches by benefit.
2294 llvm::stable_sort(matches,
2295 [](const MatchResult &lhs, const MatchResult &rhs) {
2296 return lhs.benefit > rhs.benefit;
2297 });
2298}
2299
2300LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2301 const MatchResult &match,
2302 PDLByteCodeMutableState &state) const {
2303 auto *configSet = match.pattern->getConfigSet();
2304 if (configSet)
2305 configSet->notifyRewriteBegin(rewriter);
2306
2307 // The arguments of the rewrite function are stored at the start of the
2308 // memory buffer.
2309 llvm::copy(match.values, state.memory.begin());
2310
2311 ByteCodeExecutor executor(
2312 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2313 state.opRangeMemory, state.typeRangeMemory,
2314 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2315 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2316 rewriterByteCode, state.currentPatternBenefits, patterns,
2317 constraintFunctions, rewriteFunctions);
2318 LogicalResult result =
2319 executor.execute(rewriter, /*matches=*/nullptr, match.location);
2320
2321 if (configSet)
2322 configSet->notifyRewriteEnd(rewriter);
2323
2324 // If the rewrite failed, check if the pattern rewriter can recover. If it
2325 // can, we can signal to the pattern applicator to keep trying patterns. If it
2326 // doesn't, we need to bail. Bailing here should be fine, given that we have
2327 // no means to propagate such a failure to the user, and it also indicates a
2328 // bug in the user code (i.e. failable rewrites should not be used with
2329 // pattern rewriters that don't support it).
2330 if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2331 LDBG() << " and rollback is not supported - aborting";
2332 llvm::report_fatal_error(
2333 "Native PDL Rewrite failed, but the pattern "
2334 "rewriter doesn't support recovery. Failable pattern rewrites should "
2335 "not be used with pattern rewriters that do not support them.");
2336 }
2337 return result;
2338}
for(Operation *op :ops)
return success()
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition ByteCode.cpp:175
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for extracting a value range from the given operation.
lhs
ArrayAttr()
static const mlir::GenInfo * generator
static void processValue(Value value, LiveMap &liveMap)
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition Attributes.h:73
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:153
BlockArgListType getArguments()
Definition Block.h:87
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition Builders.cpp:27
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition Liveness.h:110
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition Liveness.cpp:375
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Types.h:164
type_range getType() const
Type getType() const
Return the type of this value.
Definition Value.h:105
user_iterator user_begin() const
Definition Value.h:216
user_iterator user_end() const
Definition Value.h:217
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Value.h:233
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition ByteCode.h:236
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
Definition ByteCode.h:235
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition ByteCode.h:249
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
Definition ByteCode.h:248
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition ByteCode.h:252
AttrTypeReplacer.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
OpFoldResult size