MLIR 23.0.0git
ByteCode.cpp
Go to the documentation of this file.
1//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements MLIR to byte-code generation and the interpreter.
10//
11//===----------------------------------------------------------------------===//
12
13#include "ByteCode.h"
17#include "mlir/IR/BuiltinOps.h"
19#include "llvm/ADT/IntervalMap.h"
20#include "llvm/ADT/PostOrderIterator.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/Format.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/InterleavedRange.h"
27#include <numeric>
28#include <optional>
29
30#define DEBUG_TYPE "pdl-bytecode"
31
32using namespace mlir;
33using namespace mlir::detail;
34
35//===----------------------------------------------------------------------===//
36// PDLByteCodePattern
37//===----------------------------------------------------------------------===//
38
39PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
40 PDLPatternConfigSet *configSet,
41 ByteCodeAddr rewriterAddr) {
42 PatternBenefit benefit = matchOp.getBenefit();
43 MLIRContext *ctx = matchOp.getContext();
44
45 // Collect the set of generated operations.
46 SmallVector<StringRef, 8> generatedOps;
47 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
48 generatedOps =
49 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
50
51 // Check to see if this is pattern matches a specific operation type.
52 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
53 return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
54 generatedOps);
55 return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
56 benefit, ctx, generatedOps);
57}
58
59//===----------------------------------------------------------------------===//
60// PDLByteCodeMutableState
61//===----------------------------------------------------------------------===//
62
63/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
64/// to the position of the pattern within the range returned by
65/// `PDLByteCode::getPatterns`.
66void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
67 PatternBenefit benefit) {
68 currentPatternBenefits[patternIndex] = benefit;
69}
70
/// Clean up any allocated state after a full match/rewrite has been completed.
/// This method should be called regardless of whether the match+rewrite
/// succeeded.
75 allocatedTypeRangeMemory.clear();
76 allocatedValueRangeMemory.clear();
77}
78
79//===----------------------------------------------------------------------===//
80// Bytecode OpCodes
81//===----------------------------------------------------------------------===//
82
namespace {
/// Opcodes of the PDL bytecode.
///
/// The generator writes one of these values for each emitted instruction,
/// followed by opcode-specific operands. Enumerators take implicit, sequential
/// values starting at 0. Those numeric values are exactly what
/// `ByteCodeWriter::append(OpCode)` pushes into the stream. Reordering,
/// inserting or removing entries therefore changes the encoding, which
/// presumably has to stay in sync with the interpreter's decoder (not visible
/// in this chunk).
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal. Also emitted for
  /// `pdl_interp.check_attribute`, where the constant attribute is the second
  /// operand.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create a type range from a list of constant types.
  CreateConstantTypeRange,
  /// Create an operation.
  CreateOperation,
  /// Create a type range from a list of dynamic types.
  CreateDynamicTypeRange,
  /// Create a value range.
  CreateDynamicValueRange,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation. The numbered variants presumably
  /// bake the operand index 0-3 into the opcode, while `GetOperandN` carries
  /// it as an explicit operand (confirm against the generator).
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation. The numbering presumably follows
  /// the same scheme as the `GetOperand*` opcodes.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
173
174/// A marker used to indicate if an operation should infer types.
175static constexpr ByteCodeField kInferTypesMarker =
176 std::numeric_limits<ByteCodeField>::max();
177
178//===----------------------------------------------------------------------===//
179// ByteCode Generation
180//===----------------------------------------------------------------------===//
181
182//===----------------------------------------------------------------------===//
183// Generator
184//===----------------------------------------------------------------------===//
185
186namespace {
// Forward declarations of helper structs defined later in this namespace;
// `ByteCodeWriter` is referenced by `Generator`'s member declarations.
struct ByteCodeLiveRange;
struct ByteCodeWriter;

/// Detection alias for use with `llvm::is_detected`. It names the return type
/// of `T::getAsOpaquePointer()` and is ill-formed when `T` has no such member,
/// so `is_detected<has_pointer_traits, T>` reports whether `T` can be
/// converted to an opaque pointer. The trailing `Args` pack is unused.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
193
194/// This class represents the main generator for the pattern bytecode.
195class Generator {
196public:
197 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
198 SmallVectorImpl<ByteCodeField> &matcherByteCode,
199 SmallVectorImpl<ByteCodeField> &rewriterByteCode,
201 ByteCodeField &maxValueMemoryIndex,
202 ByteCodeField &maxOpRangeMemoryIndex,
203 ByteCodeField &maxTypeRangeMemoryIndex,
204 ByteCodeField &maxValueRangeMemoryIndex,
205 ByteCodeField &maxLoopLevel,
206 llvm::StringMap<PDLConstraintFunction> &constraintFns,
207 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
209 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
210 rewriterByteCode(rewriterByteCode), patterns(patterns),
211 maxValueMemoryIndex(maxValueMemoryIndex),
212 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
213 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
214 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
215 maxLoopLevel(maxLoopLevel), configMap(configMap) {
216 for (const auto &it : llvm::enumerate(constraintFns))
217 constraintToMemIndex.try_emplace(it.value().first(), it.index());
218 for (const auto &it : llvm::enumerate(rewriteFns))
219 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
220 }
221
222 /// Generate the bytecode for the given PDL interpreter module.
223 void generate(ModuleOp module);
224
225 /// Return the memory index to use for the given value.
226 ByteCodeField &getMemIndex(Value value) {
227 assert(valueToMemIndex.count(value) &&
228 "expected memory index to be assigned");
229 return valueToMemIndex[value];
230 }
231
232 /// Return the range memory index used to store the given range value.
233 ByteCodeField &getRangeStorageIndex(Value value) {
234 assert(valueToRangeIndex.count(value) &&
235 "expected range index to be assigned");
236 return valueToRangeIndex[value];
237 }
238
239 /// Return an index to use when referring to the given data that is uniqued in
240 /// the MLIR context.
241 template <typename T>
242 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
243 getMemIndex(T val) {
244 const void *opaqueVal = val.getAsOpaquePointer();
245
246 // Get or insert a reference to this value.
247 auto it = uniquedDataToMemIndex.try_emplace(
248 opaqueVal, maxValueMemoryIndex + uniquedData.size());
249 if (it.second)
250 uniquedData.push_back(opaqueVal);
251 return it.first->second;
252 }
253
254private:
255 /// Allocate memory indices for the results of operations within the matcher
256 /// and rewriters.
257 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
258 ModuleOp rewriterModule);
259
260 /// Generate the bytecode for the given operation.
261 void generate(Region *region, ByteCodeWriter &writer);
262 void generate(Operation *op, ByteCodeWriter &writer);
263 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
298 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
299 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
300 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
301
302 /// Mapping from value to its corresponding memory index.
303 DenseMap<Value, ByteCodeField> valueToMemIndex;
304
305 /// Mapping from a range value to its corresponding range storage index.
306 DenseMap<Value, ByteCodeField> valueToRangeIndex;
307
308 /// Mapping from the name of an externally registered rewrite to its index in
309 /// the bytecode registry.
310 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
311
312 /// Mapping from the name of an externally registered constraint to its index
313 /// in the bytecode registry.
314 llvm::StringMap<ByteCodeField> constraintToMemIndex;
315
316 /// Mapping from rewriter function name to the bytecode address of the
317 /// rewriter function in byte.
318 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
319
320 /// Mapping from a uniqued storage object to its memory index within
321 /// `uniquedData`.
322 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
323
324 /// The current level of the foreach loop.
325 ByteCodeField curLoopLevel = 0;
326
327 /// The current MLIR context.
328 MLIRContext *ctx;
329
330 /// Mapping from block to its address.
332
333 /// Data of the ByteCode class to be populated.
334 std::vector<const void *> &uniquedData;
335 SmallVectorImpl<ByteCodeField> &matcherByteCode;
336 SmallVectorImpl<ByteCodeField> &rewriterByteCode;
337 SmallVectorImpl<PDLByteCodePattern> &patterns;
338 ByteCodeField &maxValueMemoryIndex;
339 ByteCodeField &maxOpRangeMemoryIndex;
340 ByteCodeField &maxTypeRangeMemoryIndex;
341 ByteCodeField &maxValueRangeMemoryIndex;
342 ByteCodeField &maxLoopLevel;
343
344 /// A map of pattern configurations.
346};
347
348/// This class provides utilities for writing a bytecode stream.
349struct ByteCodeWriter {
350 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
351 : bytecode(bytecode), generator(generator) {}
352
353 /// Append a field to the bytecode.
354 void append(ByteCodeField field) { bytecode.push_back(field); }
355 void append(OpCode opCode) { bytecode.push_back(opCode); }
356
357 /// Append an address to the bytecode.
358 void append(ByteCodeAddr field) {
359 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
360 "unexpected ByteCode address size");
361
362 ByteCodeField fieldParts[2];
363 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
364 bytecode.append({fieldParts[0], fieldParts[1]});
365 }
366
367 /// Append a single successor to the bytecode, the exact address will need to
368 /// be resolved later.
369 void append(Block *successor) {
370 // Add back a reference to the successor so that the address can be resolved
371 // later.
372 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
373 append(ByteCodeAddr(0));
374 }
375
376 /// Append a successor range to the bytecode, the exact address will need to
377 /// be resolved later.
378 void append(SuccessorRange successors) {
379 for (Block *successor : successors)
380 append(successor);
381 }
382
383 /// Append a range of values that will be read as generic PDLValues.
384 void appendPDLValueList(OperandRange values) {
385 bytecode.push_back(values.size());
386 for (Value value : values)
387 appendPDLValue(value);
388 }
389
390 /// Append a value as a PDLValue.
391 void appendPDLValue(Value value) {
392 appendPDLValueKind(value);
393 append(value);
394 }
395
396 /// Append the PDLValue::Kind of the given value.
397 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
398
399 /// Append the PDLValue::Kind of the given type.
400 void appendPDLValueKind(Type type) {
401 PDLValue::Kind kind =
403 .Case<pdl::AttributeType>(
404 [](Type) { return PDLValue::Kind::Attribute; })
405 .Case<pdl::OperationType>(
406 [](Type) { return PDLValue::Kind::Operation; })
407 .Case([](pdl::RangeType rangeTy) {
408 if (isa<pdl::TypeType>(rangeTy.getElementType()))
409 return PDLValue::Kind::TypeRange;
410 return PDLValue::Kind::ValueRange;
411 })
412 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
413 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
414 bytecode.push_back(static_cast<ByteCodeField>(kind));
415 }
416
  /// Append a value that is stored in a memory slot rather than inline within
  /// the bytecode: only the slot index returned by `Generator::getMemIndex` is
  /// written. Enabled for types providing `getAsOpaquePointer()` and for raw
  /// pointer types.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
                   std::is_pointer<T>::value>
  append(T value) {
    bytecode.push_back(generator.getMemIndex(value));
  }

  /// Append a range: first its element count (narrowed to a ByteCodeField),
  /// then each element through the matching `append` overload. Enabled only
  /// for types without `getAsOpaquePointer()`. The defaulted `IteratorT`
  /// parameter presumably removes non-iterable types from overload resolution
  /// via SFINAE.
  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
  append(T range) {
    bytecode.push_back(llvm::size(range));
    for (auto it : range)
      append(it);
  }

  /// Append a variadic number of fields to the bytecode, one at a time and in
  /// argument order, by recursing into the other `append` overloads.
  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
    append(field);
    append(field2, fields...);
  }

  /// Append the raw opaque pointer of `value` inline in the bytecode, split
  /// across as many fields as a pointer occupies.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  appendInline(T value) {
    constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
    const void *pointer = value.getAsOpaquePointer();
    ByteCodeField fieldParts[numParts];
    std::memcpy(fieldParts, &pointer, sizeof(const void *));
    bytecode.append(fieldParts, fieldParts + numParts);
  }

  /// Successor references in the bytecode that have yet to be resolved: for
  /// each successor block, the offsets of placeholder addresses that must be
  /// overwritten once the block's address is known.
  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;

  /// The underlying bytecode buffer that every `append` writes to.
  SmallVectorImpl<ByteCodeField> &bytecode;

  /// The main generator producing PDL; queried for memory indices.
  Generator &generator;
461};
462
463/// This class represents a live range of PDL Interpreter values, containing
464/// information about when values are live within a match/rewrite.
465struct ByteCodeLiveRange {
466 using Set = llvm::IntervalMap<uint64_t, char, 16>;
467 using Allocator = Set::Allocator;
468
469 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
470
471 /// Union this live range with the one provided.
472 void unionWith(const ByteCodeLiveRange &rhs) {
473 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
474 ++it)
475 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
476 }
477
478 /// Returns true if this range overlaps with the one provided.
479 bool overlaps(const ByteCodeLiveRange &rhs) const {
480 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
481 .valid();
482 }
483
484 /// A map representing the ranges of the match/rewrite that a value is live in
485 /// the interpreter.
486 ///
487 /// We use std::unique_ptr here, because IntervalMap does not provide a
488 /// correct copy or move constructor. We can eliminate the pointer once
489 /// https://reviews.llvm.org/D113240 lands.
490 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
491
492 /// The operation range storage index for this range.
493 std::optional<unsigned> opRangeIndex;
494
495 /// The type range storage index for this range.
496 std::optional<unsigned> typeRangeIndex;
497
498 /// The value range storage index for this range.
499 std::optional<unsigned> valueRangeIndex;
500};
501} // namespace
502
503void Generator::generate(ModuleOp module) {
504 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
505 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
506 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
507 pdl_interp::PDLInterpDialect::getRewriterModuleName());
508 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
509
510 // Allocate memory indices for the results of operations within the matcher
511 // and rewriters.
512 allocateMemoryIndices(matcherFunc, rewriterModule);
513
514 // Generate code for the rewriter functions.
515 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
516 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
517 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
518 for (Operation &op : rewriterFunc.getOps())
519 generate(&op, rewriterByteCodeWriter);
520 }
521 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
522 "unexpected branches in rewriter function");
523
524 // Generate code for the matcher function.
525 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
526 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
527
528 // Resolve successor references in the matcher.
529 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
530 ByteCodeAddr addr = blockToAddr[it.first];
531 for (unsigned offsetToFix : it.second)
532 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
533 }
534}
535
536void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
537 ModuleOp rewriterModule) {
538 // Rewriters use simplistic allocation scheme that simply assigns an index to
539 // each result.
540 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
541 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
542 auto processRewriterValue = [&](Value val) {
543 valueToMemIndex.try_emplace(val, index++);
544 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
545 Type elementTy = rangeType.getElementType();
546 if (isa<pdl::TypeType>(elementTy))
547 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
548 else if (isa<pdl::ValueType>(elementTy))
549 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
550 }
551 };
552
553 for (BlockArgument arg : rewriterFunc.getArguments())
554 processRewriterValue(arg);
555 rewriterFunc.getBody().walk([&](Operation *op) {
556 for (Value result : op->getResults())
557 processRewriterValue(result);
558 });
559 if (index > maxValueMemoryIndex)
560 maxValueMemoryIndex = index;
561 if (typeRangeIndex > maxTypeRangeMemoryIndex)
562 maxTypeRangeMemoryIndex = typeRangeIndex;
563 if (valueRangeIndex > maxValueRangeMemoryIndex)
564 maxValueRangeMemoryIndex = valueRangeIndex;
565 }
566
567 // The matcher function uses a more sophisticated numbering that tries to
568 // minimize the number of memory indices assigned. This is done by determining
569 // a live range of the values within the matcher, then the allocation is just
570 // finding the minimal number of overlapping live ranges. This is essentially
571 // a simplified form of register allocation where we don't necessarily have a
572 // limited number of registers, but we still want to minimize the number used.
573 DenseMap<Operation *, unsigned> opToFirstIndex;
575
576 // A custom walk that marks the first and the last index of each operation.
577 // The entry marks the beginning of the liveness range for this operation,
578 // followed by nested operations, followed by the end of the liveness range.
579 unsigned index = 0;
580 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
581 opToFirstIndex.try_emplace(op, index++);
582 for (Region &region : op->getRegions())
583 for (Block &block : region.getBlocks())
584 for (Operation &nested : block)
585 walk(&nested);
586 opToLastIndex.try_emplace(op, index++);
587 };
588 walk(matcherFunc);
589
590 // Liveness info for each of the defs within the matcher.
591 ByteCodeLiveRange::Allocator allocator;
593
594 // Assign the root operation being matched to slot 0.
595 BlockArgument rootOpArg = matcherFunc.getArgument(0);
596 valueToMemIndex[rootOpArg] = 0;
597
598 // Walk each of the blocks, computing the def interval that the value is used.
599 Liveness matcherLiveness(matcherFunc);
600 matcherFunc->walk([&](Block *block) {
601 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
602 assert(info && "expected liveness info for block");
603 auto processValue = [&](Value value, Operation *firstUseOrDef) {
604 // We don't need to process the root op argument, this value is always
605 // assigned to the first memory slot.
606 if (value == rootOpArg)
607 return;
608
609 // Set indices for the range of this block that the value is used.
610 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
611 defRangeIt->second.liveness->insert(
612 opToFirstIndex[firstUseOrDef],
613 opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
614 /*dummyValue*/ 0);
615
616 // Check to see if this value is a range type.
617 if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
618 Type eleType = rangeTy.getElementType();
619 if (isa<pdl::OperationType>(eleType))
620 defRangeIt->second.opRangeIndex = 0;
621 else if (isa<pdl::TypeType>(eleType))
622 defRangeIt->second.typeRangeIndex = 0;
623 else if (isa<pdl::ValueType>(eleType))
624 defRangeIt->second.valueRangeIndex = 0;
625 }
626 };
627
628 // Process the live-ins of this block.
629 for (Value liveIn : info->in()) {
630 // Only process the value if it has been defined in the current region.
631 // Other values that span across pdl_interp.foreach will be added higher
632 // up. This ensures that the we keep them alive for the entire duration
633 // of the loop.
634 if (liveIn.getParentRegion() == block->getParent())
635 processValue(liveIn, &block->front());
636 }
637
638 // Process the block arguments for the entry block (those are not live-in).
639 if (block->isEntryBlock()) {
640 for (Value argument : block->getArguments())
641 processValue(argument, &block->front());
642 }
643
644 // Process any new defs within this block.
645 for (Operation &op : *block)
646 for (Value result : op.getResults())
647 processValue(result, &op);
648 });
649
650 // Greedily allocate memory slots using the computed def live ranges.
651 std::vector<ByteCodeLiveRange> allocatedIndices;
652
653 // The number of memory indices currently allocated (and its next value).
654 // Recall that the root gets allocated memory index 0.
655 ByteCodeField numIndices = 1;
656
657 // The number of memory ranges of various types (and their next values).
658 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
659
660 for (auto &defIt : valueDefRanges) {
661 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
662 ByteCodeLiveRange &defRange = defIt.second;
663
664 // Try to allocate to an existing index.
665 for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
666 ByteCodeLiveRange &existingRange = existingIndexIt.value();
667 if (!defRange.overlaps(existingRange)) {
668 existingRange.unionWith(defRange);
669 memIndex = existingIndexIt.index() + 1;
670
671 if (defRange.opRangeIndex) {
672 if (!existingRange.opRangeIndex)
673 existingRange.opRangeIndex = numOpRanges++;
674 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
675 } else if (defRange.typeRangeIndex) {
676 if (!existingRange.typeRangeIndex)
677 existingRange.typeRangeIndex = numTypeRanges++;
678 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
679 } else if (defRange.valueRangeIndex) {
680 if (!existingRange.valueRangeIndex)
681 existingRange.valueRangeIndex = numValueRanges++;
682 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
683 }
684 break;
685 }
686 }
687
688 // If no existing index could be used, add a new one.
689 if (memIndex == 0) {
690 allocatedIndices.emplace_back(allocator);
691 ByteCodeLiveRange &newRange = allocatedIndices.back();
692 newRange.unionWith(defRange);
693
694 // Allocate an index for op/type/value ranges.
695 if (defRange.opRangeIndex) {
696 newRange.opRangeIndex = numOpRanges;
697 valueToRangeIndex[defIt.first] = numOpRanges++;
698 } else if (defRange.typeRangeIndex) {
699 newRange.typeRangeIndex = numTypeRanges;
700 valueToRangeIndex[defIt.first] = numTypeRanges++;
701 } else if (defRange.valueRangeIndex) {
702 newRange.valueRangeIndex = numValueRanges;
703 valueToRangeIndex[defIt.first] = numValueRanges++;
704 }
705
706 memIndex = allocatedIndices.size();
707 ++numIndices;
708 }
709 }
710
711 // Print the index usage and ensure that we did not run out of index space.
712 LDBG() << "Allocated " << allocatedIndices.size() << " indices "
713 << "(down from initial " << valueDefRanges.size() << ").";
714 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
715 "Ran out of memory for allocated indices");
716
717 // Update the max number of indices.
718 if (numIndices > maxValueMemoryIndex)
719 maxValueMemoryIndex = numIndices;
720 if (numOpRanges > maxOpRangeMemoryIndex)
721 maxOpRangeMemoryIndex = numOpRanges;
722 if (numTypeRanges > maxTypeRangeMemoryIndex)
723 maxTypeRangeMemoryIndex = numTypeRanges;
724 if (numValueRanges > maxValueRangeMemoryIndex)
725 maxValueRangeMemoryIndex = numValueRanges;
726}
727
728void Generator::generate(Region *region, ByteCodeWriter &writer) {
729 llvm::ReversePostOrderTraversal<Region *> rpot(region);
730 for (Block *block : rpot) {
731 // Keep track of where this block begins within the matcher function.
732 blockToAddr.try_emplace(block, matcherByteCode.size());
733 for (Operation &op : *block)
734 generate(&op, writer);
735 }
736}
737
738void Generator::generate(Operation *op, ByteCodeWriter &writer) {
739 LDBG() << "Generating bytecode for operation: " << op->getName();
740 LLVM_DEBUG({
741 // The following list must contain all the operations that do not
742 // produce any bytecode.
743 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
744 writer.appendInline(op->getLoc());
745 });
747 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
748 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
749 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
750 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
751 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
752 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
753 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
754 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
755 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
756 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
757 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
758 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
759 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
760 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
761 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
762 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
763 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
764 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
765 pdl_interp::SwitchResultCountOp>(
766 [&](auto interpOp) { this->generate(interpOp, writer); })
767 .DefaultUnreachable("unknown `pdl_interp` operation");
768}
769
770void Generator::generate(pdl_interp::ApplyConstraintOp op,
771 ByteCodeWriter &writer) {
772 // Constraints that should return a value have to be registered as rewrites.
773 // If a constraint and a rewrite of similar name are registered the
774 // constraint takes precedence
775 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
776 writer.appendPDLValueList(op.getArgs());
777 writer.append(ByteCodeField(op.getIsNegated()));
778 ResultRange results = op.getResults();
779 writer.append(ByteCodeField(results.size()));
780 for (Value result : results) {
781 // We record the expected kind of the result, so that we can provide extra
782 // verification of the native rewrite function and handle the failure case
783 // of constraints accordingly.
784 writer.appendPDLValueKind(result);
785
786 // Range results also need to append the range storage index.
787 if (isa<pdl::RangeType>(result.getType()))
788 writer.append(getRangeStorageIndex(result));
789 writer.append(result);
790 }
791 writer.append(op.getSuccessors());
792}
793void Generator::generate(pdl_interp::ApplyRewriteOp op,
794 ByteCodeWriter &writer) {
795 assert(externalRewriterToMemIndex.count(op.getName()) &&
796 "expected index for rewrite function");
797 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
798 writer.appendPDLValueList(op.getArgs());
799
800 ResultRange results = op.getResults();
801 writer.append(ByteCodeField(results.size()));
802 for (Value result : results) {
803 // We record the expected kind of the result, so that we
804 // can provide extra verification of the native rewrite function.
805 writer.appendPDLValueKind(result);
806
807 // Range results also need to append the range storage index.
808 if (isa<pdl::RangeType>(result.getType()))
809 writer.append(getRangeStorageIndex(result));
810 writer.append(result);
811 }
812}
813void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
814 Value lhs = op.getLhs();
815 if (isa<pdl::RangeType>(lhs.getType())) {
816 writer.append(OpCode::AreRangesEqual);
817 writer.appendPDLValueKind(lhs);
818 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
819 return;
820 }
821
822 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
823}
824void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
825 writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
826}
827void Generator::generate(pdl_interp::CheckAttributeOp op,
828 ByteCodeWriter &writer) {
829 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
830 op.getSuccessors());
831}
832void Generator::generate(pdl_interp::CheckOperandCountOp op,
833 ByteCodeWriter &writer) {
834 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
835 static_cast<ByteCodeField>(op.getCompareAtLeast()),
836 op.getSuccessors());
837}
838void Generator::generate(pdl_interp::CheckOperationNameOp op,
839 ByteCodeWriter &writer) {
840 writer.append(OpCode::CheckOperationName, op.getInputOp(),
841 OperationName(op.getName(), ctx), op.getSuccessors());
842}
843void Generator::generate(pdl_interp::CheckResultCountOp op,
844 ByteCodeWriter &writer) {
845 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
846 static_cast<ByteCodeField>(op.getCompareAtLeast()),
847 op.getSuccessors());
848}
849void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
850 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
851 op.getSuccessors());
852}
853void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
854 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
855 op.getSuccessors());
856}
857void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
858 assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
859 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
860}
861void Generator::generate(pdl_interp::CreateAttributeOp op,
862 ByteCodeWriter &writer) {
863 // Simply repoint the memory index of the result to the constant.
864 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
865}
866void Generator::generate(pdl_interp::CreateOperationOp op,
867 ByteCodeWriter &writer) {
868 writer.append(OpCode::CreateOperation, op.getResultOp(),
869 OperationName(op.getName(), ctx));
870 writer.appendPDLValueList(op.getInputOperands());
871
872 // Add the attributes.
873 OperandRange attributes = op.getInputAttributes();
874 writer.append(static_cast<ByteCodeField>(attributes.size()));
875 for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
876 writer.append(std::get<0>(it), std::get<1>(it));
877
878 // Add the result types. If the operation has inferred results, we use a
879 // marker "size" value. Otherwise, we add the list of explicit result types.
880 if (op.getInferredResultTypes())
881 writer.append(kInferTypesMarker);
882 else
883 writer.appendPDLValueList(op.getInputResultTypes());
884}
885void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
886 // Append the correct opcode for the range type.
887 TypeSwitch<Type>(op.getType().getElementType())
888 .Case(
889 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
890 .Case([&](pdl::ValueType) {
891 writer.append(OpCode::CreateDynamicValueRange);
892 });
893
894 writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
895 writer.appendPDLValueList(op->getOperands());
896}
897void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
898 // Simply repoint the memory index of the result to the constant.
899 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
900}
901void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
902 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
903 getRangeStorageIndex(op.getResult()), op.getValue());
904}
905void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
906 writer.append(OpCode::EraseOp, op.getInputOp());
907}
908void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
909 OpCode opCode =
910 TypeSwitch<Type, OpCode>(op.getResult().getType())
911 .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
912 .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
913 .Case([](pdl::TypeType) { return OpCode::ExtractType; })
914 .DefaultUnreachable("unsupported element type");
915 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
916}
917void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
918 writer.append(OpCode::Finalize);
919}
920void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
921 BlockArgument arg = op.getLoopVariable();
922 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
923 writer.appendPDLValueKind(arg.getType());
924 writer.append(curLoopLevel, op.getSuccessor());
925 ++curLoopLevel;
926 if (curLoopLevel > maxLoopLevel)
927 maxLoopLevel = curLoopLevel;
928 generate(&op.getRegion(), writer);
929 --curLoopLevel;
930}
931void Generator::generate(pdl_interp::GetAttributeOp op,
932 ByteCodeWriter &writer) {
933 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
934 op.getNameAttr());
935}
936void Generator::generate(pdl_interp::GetAttributeTypeOp op,
937 ByteCodeWriter &writer) {
938 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
939}
940void Generator::generate(pdl_interp::GetDefiningOpOp op,
941 ByteCodeWriter &writer) {
942 writer.append(OpCode::GetDefiningOp, op.getInputOp());
943 writer.appendPDLValue(op.getValue());
944}
945void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
946 uint32_t index = op.getIndex();
947 if (index < 4)
948 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
949 else
950 writer.append(OpCode::GetOperandN, index);
951 writer.append(op.getInputOp(), op.getValue());
952}
953void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
954 Value result = op.getValue();
955 std::optional<uint32_t> index = op.getIndex();
956 writer.append(OpCode::GetOperands,
957 index.value_or(std::numeric_limits<uint32_t>::max()),
958 op.getInputOp());
959 if (isa<pdl::RangeType>(result.getType()))
960 writer.append(getRangeStorageIndex(result));
961 else
962 writer.append(std::numeric_limits<ByteCodeField>::max());
963 writer.append(result);
964}
965void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
966 uint32_t index = op.getIndex();
967 if (index < 4)
968 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
969 else
970 writer.append(OpCode::GetResultN, index);
971 writer.append(op.getInputOp(), op.getValue());
972}
973void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
974 Value result = op.getValue();
975 std::optional<uint32_t> index = op.getIndex();
976 writer.append(OpCode::GetResults,
977 index.value_or(std::numeric_limits<uint32_t>::max()),
978 op.getInputOp());
979 if (isa<pdl::RangeType>(result.getType()))
980 writer.append(getRangeStorageIndex(result));
981 else
982 writer.append(std::numeric_limits<ByteCodeField>::max());
983 writer.append(result);
984}
985void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
986 Value operations = op.getOperations();
987 ByteCodeField rangeIndex = getRangeStorageIndex(operations);
988 writer.append(OpCode::GetUsers, operations, rangeIndex);
989 writer.appendPDLValue(op.getValue());
990}
991void Generator::generate(pdl_interp::GetValueTypeOp op,
992 ByteCodeWriter &writer) {
993 if (isa<pdl::RangeType>(op.getType())) {
994 Value result = op.getResult();
995 writer.append(OpCode::GetValueRangeTypes, result,
996 getRangeStorageIndex(result), op.getValue());
997 } else {
998 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
999 }
1000}
1001void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1002 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1003}
1004void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1005 ByteCodeField patternIndex = patterns.size();
1006 patterns.emplace_back(PDLByteCodePattern::create(
1007 op, configMap.lookup(op),
1008 rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1009 writer.append(OpCode::RecordMatch, patternIndex,
1010 SuccessorRange(op.getOperation()), op.getMatchedOps());
1011 writer.appendPDLValueList(op.getInputs());
1012}
1013void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1014 writer.append(OpCode::ReplaceOp, op.getInputOp());
1015 writer.appendPDLValueList(op.getReplValues());
1016}
1017void Generator::generate(pdl_interp::SwitchAttributeOp op,
1018 ByteCodeWriter &writer) {
1019 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1020 op.getCaseValuesAttr(), op.getSuccessors());
1021}
1022void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1023 ByteCodeWriter &writer) {
1024 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1025 op.getCaseValuesAttr(), op.getSuccessors());
1026}
1027void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1028 ByteCodeWriter &writer) {
1029 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1030 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1031 });
1032 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1033 op.getSuccessors());
1034}
1035void Generator::generate(pdl_interp::SwitchResultCountOp op,
1036 ByteCodeWriter &writer) {
1037 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1038 op.getCaseValuesAttr(), op.getSuccessors());
1039}
1040void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1041 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1042 op.getSuccessors());
1043}
1044void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1045 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1046 op.getSuccessors());
1047}
1048
1049//===----------------------------------------------------------------------===//
1050// PDLByteCode
1051//===----------------------------------------------------------------------===//
1052
/// Construct the bytecode for the patterns in `module`.
///
/// A `Generator` is handed references to this object's uniqued data, matcher
/// and rewriter bytecode buffers, pattern list, and memory/range/loop sizing
/// fields, and `generator.generate(module)` is run over the module. The
/// native constraint and rewrite functions are then moved into flat vectors
/// that the executor indexes by position.
PDLByteCode::PDLByteCode(
    ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
    llvm::StringMap<PDLConstraintFunction> constraintFns,
    llvm::StringMap<PDLRewriteFunction> rewriteFns)
    : configs(std::move(configs)) {
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
                      maxLoopLevel, constraintFns, rewriteFns, configMap);
  generator.generate(module);

  // Initialize the external functions. The vectors are filled in StringMap
  // iteration order. NOTE(review): this presumably must match the indices the
  // Generator assigned from the same maps; confirm if either side changes.
  for (auto &it : constraintFns)
    constraintFunctions.push_back(std::move(it.second));
  for (auto &it : rewriteFns)
    rewriteFunctions.push_back(std::move(it.second));
}
1071
1072/// Initialize the given state such that it can be used to execute the current
1073/// bytecode.
1075 state.memory.resize(maxValueMemoryIndex, nullptr);
1076 state.opRangeMemory.resize(maxOpRangeCount);
1077 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1078 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1079 state.loopIndex.resize(maxLoopLevel, 0);
1080 state.currentPatternBenefits.reserve(patterns.size());
1081 for (const PDLByteCodePattern &pattern : patterns)
1082 state.currentPatternBenefits.push_back(pattern.getBenefit());
1083}
1084
1085//===----------------------------------------------------------------------===//
1086// ByteCode Execution
1087//===----------------------------------------------------------------------===//
1088
1089namespace {
1090/// This class is an instantiation of the PDLResultList that provides access to
1091/// the returned results. This API is not on `PDLResultList` to avoid
1092/// overexposing access to information specific solely to the ByteCode.
1093class ByteCodeRewriteResultList : public PDLResultList {
1094public:
1095 ByteCodeRewriteResultList(unsigned maxNumResults)
1096 : PDLResultList(maxNumResults) {}
1097
1098 /// Return the list of PDL results.
1099 MutableArrayRef<PDLValue> getResults() { return results; }
1100
1101 /// Return the type ranges allocated by this list.
1102 MutableArrayRef<std::vector<Type>> getAllocatedTypeRanges() {
1103 return allocatedTypeRanges;
1104 }
1105
1106 /// Return the value ranges allocated by this list.
1107 MutableArrayRef<std::vector<Value>> getAllocatedValueRanges() {
1108 return allocatedValueRanges;
1109 }
1110};
1111
/// This class provides support for executing a bytecode stream.
///
/// The executor holds all of its state (execution memory, range storage, loop
/// indices, the bytecode, patterns, and native functions) through non-owning
/// views (`ArrayRef`/`MutableArrayRef`) or references, so the referenced
/// storage must outlive the executor.
class ByteCodeExecutor {
public:
  ByteCodeExecutor(const ByteCodeField *curCodeIt,
                   MutableArrayRef<const void *> memory,
                   MutableArrayRef<std::vector<Operation *>> opRangeMemory,
                   MutableArrayRef<TypeRange> typeRangeMemory,
                   std::vector<std::vector<Type>> &allocatedTypeRangeMemory,
                   MutableArrayRef<ValueRange> valueRangeMemory,
                   std::vector<std::vector<Value>> &allocatedValueRangeMemory,
                   MutableArrayRef<unsigned> loopIndex,
                   ArrayRef<const void *> uniquedMemory,
                   ArrayRef<ByteCodeField> code,
                   ArrayRef<PatternBenefit> currentPatternBenefits,
                   ArrayRef<PDLByteCodePattern> patterns,
                   ArrayRef<PDLConstraintFunction> constraintFunctions,
                   ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context.
  LogicalResult
  execute(PatternRewriter &rewriter,
          SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
          std::optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  void executeApplyConstraint(PatternRewriter &rewriter);
  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateConstantTypeRange();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  template <typename T>
  void executeDynamicCreateRange(StringRef type);
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();
  /// Store the results of a native constraint/rewrite call into memory. If
  /// the call failed, only skip past the bytecode describing those results.
  void processNativeFunResults(ByteCodeRewriteResultList &results,
                               unsigned numResults,
                               LogicalResult &rewriteResult);

  /// Pushes a code iterator to the stack.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }

  /// Pops the most recently pushed code iterator off the stack and resumes
  /// execution from it. The stack must not be empty.
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.pop_back_val();
  }

  /// Return the bytecode iterator at the start of the current op code.
  /// NOTE(review): the inline Location is only accounted for when the
  /// LLVM_DEBUG body runs; this presumably mirrors the generator writing it
  /// under the same condition. Confirm the two stay in sync.
  const ByteCodeField *getPrevCodeIt() const {
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a count-prefixed list of values from the bytecode buffer into
  /// `list`, replacing its previous contents.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a count-prefixed list of values from the bytecode buffer. Each entry
  /// is tagged with its PDL value kind and is either a single element
  /// (appended directly) or a range (expanded in place). Unlike the templated
  /// overload above, `list` is NOT cleared first; entries are appended.
  void readList(SmallVectorImpl<Type> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
        list.push_back(read<Type>());
      } else {
        TypeRange *values = read<TypeRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }
  void readList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(read<Value>());
      } else {
        ValueRange *values = read<ValueRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer. The raw pointer bytes are
  /// copied out of the bytecode, which is then advanced past the
  /// `ByteCodeField`s they occupy.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
  readInline() {
    const void *pointer;
    std::memcpy(&pointer, curCodeIt, sizeof(const void *));
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Advance the bytecode iterator past `skipN` fields without reading them.
  void skip(size_t skipN) { curCodeIt += skipN; }

  /// Jump to a specific successor based on a predicate value: successor 0
  /// when `isTrue`, successor 1 otherwise.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. Successor
  /// addresses are stored consecutively, each taking two `ByteCodeField`s
  /// (one `ByteCodeAddr`).
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases. Successor 0
  /// is the default destination; a match on the case at position `i` jumps
  /// to successor `i + 1`.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LDBG() << "Switch operation:\n * Value: " << value
           << "\n * Cases: " << llvm::interleaved(cases);

    // Check to see if the attribute value is within the case list. Jump to
    // the correct successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  ///
  /// Resolve the memory index at the current position to an opaque pointer.
  /// Indices below `memory.size()` address the mutable execution memory;
  /// larger indices address `uniquedMemory`, offset by `memory.size()`.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Pointer types: reinterpret the stored opaque pointer.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Pointer-like class types (other than PDLValue): rebuild from the opaque
  /// pointer.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// PDLValue: a kind tag followed by the value encoded for that kind.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// ByteCodeAddr: stored inline across two consecutive fields.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// ByteCodeField: a single raw field.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// PDLValue::Kind: a single raw field reinterpreted as the enum.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// Assign the given range to the given memory index. This allocates a new
  /// range object if necessary. An empty range is stored as a default
  /// constructed range without allocating; otherwise the elements are copied
  /// into newly allocated storage that the range slot then refers to.
  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
                           unsigned rangeIndex) {
    // Utility functor used to type-erase the assignment.
    auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
      // If the input range is empty, we don't need to allocate anything.
      if (range.empty()) {
        rangeMemory[rangeIndex] = {};
      } else {
        // Assign this to the range slot and use the range as the value for the
        // memory index.
        allocatedRangeMemory.emplace_back(range.begin(), range.end());
        rangeMemory[rangeIndex] = allocatedRangeMemory.back();
      }
      memory[memIndex] = &rangeMemory[rangeIndex];
    };

    // Dispatch based on the concrete range type.
    if constexpr (std::is_same_v<T, Type>) {
      return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
    } else if constexpr (std::is_same_v<T, Value>) {
      return assignRange(allocatedValueRangeMemory, valueRangeMemory);
    } else {
      llvm_unreachable("unhandled range type");
    }
  }

  /// The underlying bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;
  MutableArrayRef<std::vector<Operation *>> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<std::vector<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<std::vector<Value>> &allocatedValueRangeMemory;

  /// The current loop indices.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
1414} // namespace
1415
1416void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1417 LDBG() << "Executing ApplyConstraint:";
1418 ByteCodeField fun_idx = read();
1419 SmallVector<PDLValue, 16> args;
1420 readList<PDLValue>(args);
1421
1422 LDBG() << " * Arguments: " << llvm::interleaved(args);
1423
1424 ByteCodeField isNegated = read();
1425 LDBG() << " * isNegated: " << isNegated;
1426
1427 ByteCodeField numResults = read();
1428 const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1429 ByteCodeRewriteResultList results(numResults);
1430 LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1431 [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1432 if (succeeded(rewriteResult)) {
1433 LDBG() << " * Constraint succeeded, results: "
1434 << llvm::interleaved(constraintResults);
1435 } else {
1436 LDBG() << " * Constraint failed";
1437 }
1438 assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1439 "native PDL rewrite function succeeded but returned "
1440 "unexpected number of results");
1441 processNativeFunResults(results, numResults, rewriteResult);
1442
1443 // Depending on the constraint jump to the proper destination.
1444 selectJump(isNegated != succeeded(rewriteResult));
1445}
1446
1447LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1448 LDBG() << "Executing ApplyRewrite:";
1449 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1450 SmallVector<PDLValue, 16> args;
1451 readList<PDLValue>(args);
1452
1453 LDBG() << " * Arguments: " << llvm::interleaved(args);
1454
1455 // Execute the rewrite function.
1456 ByteCodeField numResults = read();
1457 ByteCodeRewriteResultList results(numResults);
1458 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1459
1460 assert(results.getResults().size() == numResults &&
1461 "native PDL rewrite function returned unexpected number of results");
1462
1463 processNativeFunResults(results, numResults, rewriteResult);
1464
1465 if (failed(rewriteResult)) {
1466 LDBG() << " - Failed";
1467 return failure();
1468 }
1469 return success();
1470}
1471
1472void ByteCodeExecutor::processNativeFunResults(
1473 ByteCodeRewriteResultList &results, unsigned numResults,
1474 LogicalResult &rewriteResult) {
1475 if (failed(rewriteResult)) {
1476 // Skip the according number of values on the buffer on failure and exit
1477 // early as there are no results to process.
1478 for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1479 const PDLValue::Kind resultKind = read<PDLValue::Kind>();
1480 if (resultKind == PDLValue::Kind::TypeRange ||
1481 resultKind == PDLValue::Kind::ValueRange) {
1482 skip(2);
1483 } else {
1484 skip(1);
1485 }
1486 }
1487 return;
1488 }
1489
1490 // Store the results in the bytecode memory
1491 for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1492 PDLValue::Kind resultKind = read<PDLValue::Kind>();
1493 (void)resultKind;
1494 PDLValue result = results.getResults()[resultIdx];
1495 LDBG() << " * Result: " << result;
1496 assert(result.getKind() == resultKind &&
1497 "native PDL rewrite function returned an unexpected type of "
1498 "result");
1499 // If the result is a range, we need to copy it over to the bytecodes
1500 // range memory.
1501 if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1502 unsigned rangeIndex = read();
1503 typeRangeMemory[rangeIndex] = *typeRange;
1504 memory[read()] = &typeRangeMemory[rangeIndex];
1505 } else if (std::optional<ValueRange> valueRange =
1506 result.dyn_cast<ValueRange>()) {
1507 unsigned rangeIndex = read();
1508 valueRangeMemory[rangeIndex] = *valueRange;
1509 memory[read()] = &valueRangeMemory[rangeIndex];
1510 } else {
1511 memory[read()] = result.getAsOpaquePointer();
1512 }
1513 }
1514
1515 // Copy over any underlying storage allocated for result ranges.
1516 for (auto &it : results.getAllocatedTypeRanges())
1517 allocatedTypeRangeMemory.push_back(std::move(it));
1518 for (auto &it : results.getAllocatedValueRanges())
1519 allocatedValueRangeMemory.push_back(std::move(it));
1520}
1521
1522void ByteCodeExecutor::executeAreEqual() {
1523 LDBG() << "Executing AreEqual:";
1524 const void *lhs = read<const void *>();
1525 const void *rhs = read<const void *>();
1526
1527 LDBG() << " * " << lhs << " == " << rhs;
1528 selectJump(lhs == rhs);
1529}
1530
1531void ByteCodeExecutor::executeAreRangesEqual() {
1532 LDBG() << "Executing AreRangesEqual:";
1533 PDLValue::Kind valueKind = read<PDLValue::Kind>();
1534 const void *lhs = read<const void *>();
1535 const void *rhs = read<const void *>();
1536
1537 switch (valueKind) {
1538 case PDLValue::Kind::TypeRange: {
1539 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1540 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1541 LDBG() << " * " << lhs << " == " << rhs;
1542 selectJump(*lhsRange == *rhsRange);
1543 break;
1544 }
1545 case PDLValue::Kind::ValueRange: {
1546 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1547 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1548 LDBG() << " * " << lhs << " == " << rhs;
1549 selectJump(*lhsRange == *rhsRange);
1550 break;
1551 }
1552 default:
1553 llvm_unreachable("unexpected `AreRangesEqual` value kind");
1554 }
1555}
1556
1557void ByteCodeExecutor::executeBranch() {
1558 LDBG() << "Executing Branch";
1559 curCodeIt = &code[read<ByteCodeAddr>()];
1560}
1561
1562void ByteCodeExecutor::executeCheckOperandCount() {
1563 LDBG() << "Executing CheckOperandCount:";
1564 Operation *op = read<Operation *>();
1565 uint32_t expectedCount = read<uint32_t>();
1566 bool compareAtLeast = read();
1567
1568 LDBG() << " * Found: " << op->getNumOperands()
1569 << "\n * Expected: " << expectedCount
1570 << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
1571 if (compareAtLeast)
1572 selectJump(op->getNumOperands() >= expectedCount);
1573 else
1574 selectJump(op->getNumOperands() == expectedCount);
1575}
1576
1577void ByteCodeExecutor::executeCheckOperationName() {
1578 LDBG() << "Executing CheckOperationName:";
1579 Operation *op = read<Operation *>();
1580 OperationName expectedName = read<OperationName>();
1581
1582 LDBG() << " * Found: \"" << op->getName() << "\"\n * Expected: \""
1583 << expectedName << "\"";
1584 selectJump(op->getName() == expectedName);
1585}
1586
1587void ByteCodeExecutor::executeCheckResultCount() {
1588 LDBG() << "Executing CheckResultCount:";
1589 Operation *op = read<Operation *>();
1590 uint32_t expectedCount = read<uint32_t>();
1591 bool compareAtLeast = read();
1592
1593 LDBG() << " * Found: " << op->getNumResults()
1594 << "\n * Expected: " << expectedCount
1595 << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
1596 if (compareAtLeast)
1597 selectJump(op->getNumResults() >= expectedCount);
1598 else
1599 selectJump(op->getNumResults() == expectedCount);
1600}
1601
1602void ByteCodeExecutor::executeCheckTypes() {
1603 LDBG() << "Executing AreEqual:";
1604 TypeRange *lhs = read<TypeRange *>();
1605 Attribute rhs = read<Attribute>();
1606 LDBG() << " * " << lhs << " == " << rhs;
1607
1608 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1609}
1610
1611void ByteCodeExecutor::executeContinue() {
1612 ByteCodeField level = read();
1613 LDBG() << "Executing Continue\n * Level: " << level;
1614 ++loopIndex[level];
1615 popCodeIt();
1616}
1617
1618void ByteCodeExecutor::executeCreateConstantTypeRange() {
1619 LDBG() << "Executing CreateConstantTypeRange:";
1620 unsigned memIndex = read();
1621 unsigned rangeIndex = read();
1622 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1623
1624 LDBG() << " * Types: " << typesAttr;
1625 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1626 rangeIndex);
1627}
1628
1629void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1630 Location mainRewriteLoc) {
1631 LDBG() << "Executing CreateOperation:";
1632
1633 unsigned memIndex = read();
1634 OperationState state(mainRewriteLoc, read<OperationName>());
1635 readList(state.operands);
1636 for (unsigned i = 0, e = read(); i != e; ++i) {
1637 StringAttr name = read<StringAttr>();
1638 if (Attribute attr = read<Attribute>())
1639 state.addAttribute(name, attr);
1640 }
1641
1642 // Read in the result types. If the "size" is the sentinel value, this
1643 // indicates that the result types should be inferred.
1644 unsigned numResults = read();
1645 if (numResults == kInferTypesMarker) {
1646 InferTypeOpInterface::Concept *inferInterface =
1647 state.name.getInterface<InferTypeOpInterface>();
1648 assert(inferInterface &&
1649 "expected operation to provide InferTypeOpInterface");
1650
1651 // TODO: Handle failure.
1652 if (failed(inferInterface->inferReturnTypes(
1653 state.getContext(), state.location, state.operands,
1654 state.attributes.getDictionary(state.getContext()),
1655 state.getRawProperties(), state.regions, state.types)))
1656 return;
1657 } else {
1658 // Otherwise, this is a fixed number of results.
1659 for (unsigned i = 0; i != numResults; ++i) {
1660 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1661 state.types.push_back(read<Type>());
1662 } else {
1663 TypeRange *resultTypes = read<TypeRange *>();
1664 state.types.append(resultTypes->begin(), resultTypes->end());
1665 }
1666 }
1667 }
1668
1669 Operation *resultOp = rewriter.create(state);
1670 memory[memIndex] = resultOp;
1671
1672 LDBG() << " * Attributes: "
1673 << state.attributes.getDictionary(state.getContext())
1674 << "\n * Operands: " << llvm::interleaved(state.operands)
1675 << "\n * Result Types: " << llvm::interleaved(state.types)
1676 << "\n * Result: " << *resultOp;
1677}
1678
1679template <typename T>
1680void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1681 LDBG() << "Executing CreateDynamic" << type << "Range:";
1682 unsigned memIndex = read();
1683 unsigned rangeIndex = read();
1684 SmallVector<T> values;
1685 readList(values);
1686
1687 LDBG() << " * " << type << "s: " << llvm::interleaved(values);
1688
1689 assignRangeToMemory(values, memIndex, rangeIndex);
1690}
1691
1692void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1693 LDBG() << "Executing EraseOp:";
1694 Operation *op = read<Operation *>();
1695
1696 LDBG() << " * Operation: " << *op;
1697 rewriter.eraseOp(op);
1698}
1699
1700template <typename T, typename Range, PDLValue::Kind kind>
1701void ByteCodeExecutor::executeExtract() {
1702 LDBG() << "Executing Extract" << kind << ":";
1703 Range *range = read<Range *>();
1704 unsigned index = read<uint32_t>();
1705 unsigned memIndex = read();
1706
1707 if (!range) {
1708 memory[memIndex] = nullptr;
1709 return;
1710 }
1711
1712 T result = index < range->size() ? (*range)[index] : T();
1713 LDBG() << " * " << kind << "s(" << range->size() << ")";
1714 LDBG() << " * Index: " << index;
1715 LDBG() << " * Result: " << result;
1716 storeToMemory(memIndex, result);
1717}
1718
1719void ByteCodeExecutor::executeFinalize() { LDBG() << "Executing Finalize"; }
1720
1721void ByteCodeExecutor::executeForEach() {
1722 LDBG() << "Executing ForEach:";
1723 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1724 unsigned rangeIndex = read();
1725 unsigned memIndex = read();
1726 const void *value = nullptr;
1727
1728 switch (read<PDLValue::Kind>()) {
1729 case PDLValue::Kind::Operation: {
1730 unsigned &index = loopIndex[read()];
1731 ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1732 assert(index <= array.size() && "iterated past the end");
1733 if (index < array.size()) {
1734 LDBG() << " * Result: " << array[index];
1735 value = array[index];
1736 break;
1737 }
1738
1739 LDBG() << " * Done";
1740 index = 0;
1741 selectJump(size_t(0));
1742 return;
1743 }
1744 case PDLValue::Kind::Value: {
1745 unsigned &index = loopIndex[read()];
1746 ValueRange range = valueRangeMemory[rangeIndex];
1747 assert(index <= range.size() && "iterated past the end");
1748 if (index < range.size()) {
1749 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n");
1750 value = range[index].getAsOpaquePointer();
1751 break;
1752 }
1753
1754 LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1755 index = 0;
1756 selectJump(size_t(0));
1757 return;
1758 }
1759 case PDLValue::Kind::Type: {
1760 unsigned &index = loopIndex[read()];
1761 TypeRange range = typeRangeMemory[rangeIndex];
1762 assert(index <= range.size() && "iterated past the end");
1763 if (index < range.size()) {
1764 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n");
1765 value = range[index].getAsOpaquePointer();
1766 break;
1767 }
1768
1769 LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1770 index = 0;
1771 selectJump(size_t(0));
1772 return;
1773 }
1774 default:
1775 llvm_unreachable("unexpected `ForEach` value kind");
1776 }
1777
1778 // Store the iterate value and the stack address.
1779 memory[memIndex] = value;
1780 pushCodeIt(prevCodeIt);
1781
1782 // Skip over the successor (we will enter the body of the loop).
1783 read<ByteCodeAddr>();
1784}
1785
1786void ByteCodeExecutor::executeGetAttribute() {
1787 LDBG() << "Executing GetAttribute:";
1788 unsigned memIndex = read();
1789 Operation *op = read<Operation *>();
1790 StringAttr attrName = read<StringAttr>();
1791 Attribute attr = op->getAttr(attrName);
1792
1793 LDBG() << " * Operation: " << *op << "\n * Attribute: " << attrName
1794 << "\n * Result: " << attr;
1795 memory[memIndex] = attr.getAsOpaquePointer();
1796}
1797
1798void ByteCodeExecutor::executeGetAttributeType() {
1799 LDBG() << "Executing GetAttributeType:";
1800 unsigned memIndex = read();
1801 Attribute attr = read<Attribute>();
1802 Type type;
1803 if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1804 type = typedAttr.getType();
1805
1806 LDBG() << " * Attribute: " << attr << "\n * Result: " << type;
1807 memory[memIndex] = type.getAsOpaquePointer();
1808}
1809
1810void ByteCodeExecutor::executeGetDefiningOp() {
1811 LDBG() << "Executing GetDefiningOp:";
1812 unsigned memIndex = read();
1813 Operation *op = nullptr;
1814 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1815 Value value = read<Value>();
1816 if (value)
1817 op = value.getDefiningOp();
1818 LDBG() << " * Value: " << value;
1819 } else {
1820 ValueRange *values = read<ValueRange *>();
1821 if (values && !values->empty()) {
1822 op = values->front().getDefiningOp();
1823 }
1824 LDBG() << " * Values: " << values;
1825 }
1826
1827 LDBG() << " * Result: " << op;
1828 memory[memIndex] = op;
1829}
1830
1831void ByteCodeExecutor::executeGetOperand(unsigned index) {
1832 Operation *op = read<Operation *>();
1833 unsigned memIndex = read();
1834 Value operand =
1835 index < op->getNumOperands() ? op->getOperand(index) : Value();
1836
1837 LDBG() << " * Operation: " << *op << "\n * Index: " << index
1838 << "\n * Result: " << operand;
1839 memory[memIndex] = operand.getAsOpaquePointer();
1840}
1841
1842/// This function is the internal implementation of `GetResults` and
1843/// `GetOperands` that provides support for extracting a value range from the
1844/// given operation.
1845template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1846static void *
1847executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1848 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1849 MutableArrayRef<ValueRange> valueRangeMemory) {
1850 // Check for the sentinel index that signals that all values should be
1851 // returned.
1852 if (index == std::numeric_limits<uint32_t>::max()) {
1853 LDBG() << " * Getting all values";
1854 // `values` is already the full value range.
1855
1856 // Otherwise, check to see if this operation uses AttrSizedSegments.
1857 } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1858 LDBG() << " * Extracting values from `" << attrSizedSegments << "`";
1859
1860 auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1861 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1862 return nullptr;
1863
1864 ArrayRef<int32_t> segments = segmentAttr;
1865 unsigned startIndex = llvm::sum_of(segments.take_front(index));
1866 values = values.slice(startIndex, *std::next(segments.begin(), index));
1867
1868 LDBG() << " * Extracting range[" << startIndex << ", "
1869 << *std::next(segments.begin(), index) << "]";
1870
1871 // Otherwise, assume this is the last operand group of the operation.
1872 // FIXME: We currently don't support operations with
1873 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1874 // have a way to detect it's presence.
1875 } else if (values.size() >= index) {
1876 LDBG() << " * Treating values as trailing variadic range";
1877 values = values.drop_front(index);
1878
1879 // If we couldn't detect a way to compute the values, bail out.
1880 } else {
1881 return nullptr;
1882 }
1883
1884 // If the range index is valid, we are returning a range.
1885 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1886 valueRangeMemory[rangeIndex] = values;
1887 return &valueRangeMemory[rangeIndex];
1888 }
1889
1890 // If a range index wasn't provided, the range is required to be non-variadic.
1891 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1892}
1893
1894void ByteCodeExecutor::executeGetOperands() {
1895 LDBG() << "Executing GetOperands:";
1896 unsigned index = read<uint32_t>();
1897 Operation *op = read<Operation *>();
1898 ByteCodeField rangeIndex = read();
1899
1901 op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
1902 valueRangeMemory);
1903 if (!result)
1904 LDBG() << " * Invalid operand range";
1905 memory[read()] = result;
1906}
1907
1908void ByteCodeExecutor::executeGetResult(unsigned index) {
1909 Operation *op = read<Operation *>();
1910 unsigned memIndex = read();
1911 OpResult result =
1912 index < op->getNumResults() ? op->getResult(index) : OpResult();
1913
1914 LDBG() << " * Operation: " << *op << "\n * Index: " << index
1915 << "\n * Result: " << result;
1916 memory[memIndex] = result.getAsOpaquePointer();
1917}
1918
1919void ByteCodeExecutor::executeGetResults() {
1920 LDBG() << "Executing GetResults:";
1921 unsigned index = read<uint32_t>();
1922 Operation *op = read<Operation *>();
1923 ByteCodeField rangeIndex = read();
1924
1926 op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
1927 valueRangeMemory);
1928 if (!result)
1929 LDBG() << " * Invalid result range";
1930 memory[read()] = result;
1931}
1932
1933void ByteCodeExecutor::executeGetUsers() {
1934 LDBG() << "Executing GetUsers:";
1935 unsigned memIndex = read();
1936 unsigned rangeIndex = read();
1937 std::vector<Operation *> &range = opRangeMemory[rangeIndex];
1938 memory[memIndex] = &range;
1939
1940 range.clear();
1941 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1942 // Read the value.
1943 Value value = read<Value>();
1944 if (!value)
1945 return;
1946 LDBG() << " * Value: " << value;
1947
1948 range.assign(value.user_begin(), value.user_end());
1949 } else {
1950 // Read a range of values.
1951 ValueRange *values = read<ValueRange *>();
1952 if (!values)
1953 return;
1954 LDBG() << " * Values (" << values->size()
1955 << "): " << llvm::interleaved(*values);
1956
1957 for (Value value : *values)
1958 range.insert(range.end(), value.user_begin(), value.user_end());
1959 }
1960
1961 LDBG() << " * Result: " << range.size() << " operations";
1962}
1963
1964void ByteCodeExecutor::executeGetValueType() {
1965 LDBG() << "Executing GetValueType:";
1966 unsigned memIndex = read();
1967 Value value = read<Value>();
1968 Type type = value ? value.getType() : Type();
1969
1970 LDBG() << " * Value: " << value << "\n * Result: " << type;
1971 memory[memIndex] = type.getAsOpaquePointer();
1972}
1973
1974void ByteCodeExecutor::executeGetValueRangeTypes() {
1975 LDBG() << "Executing GetValueRangeTypes:";
1976 unsigned memIndex = read();
1977 unsigned rangeIndex = read();
1978 ValueRange *values = read<ValueRange *>();
1979 if (!values) {
1980 LDBG() << " * Values: <NULL>";
1981 memory[memIndex] = nullptr;
1982 return;
1983 }
1984
1985 LDBG() << " * Values (" << values->size()
1986 << "): " << llvm::interleaved(*values)
1987 << "\n * Result: " << llvm::interleaved(values->getType());
1988 typeRangeMemory[rangeIndex] = values->getType();
1989 memory[memIndex] = &typeRangeMemory[rangeIndex];
1990}
1991
1992void ByteCodeExecutor::executeIsNotNull() {
1993 LDBG() << "Executing IsNotNull:";
1994 const void *value = read<const void *>();
1995
1996 LDBG() << " * Value: " << value;
1997 selectJump(value != nullptr);
1998}
1999
2000void ByteCodeExecutor::executeRecordMatch(
2001 PatternRewriter &rewriter,
2002 SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
2003 LDBG() << "Executing RecordMatch:";
2004 unsigned patternIndex = read();
2005 PatternBenefit benefit = currentPatternBenefits[patternIndex];
2006 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
2007
2008 // If the benefit of the pattern is impossible, skip the processing of the
2009 // rest of the pattern.
2010 if (benefit.isImpossibleToMatch()) {
2011 LDBG() << " * Benefit: Impossible To Match";
2012 curCodeIt = dest;
2013 return;
2014 }
2015
2016 // Create a fused location containing the locations of each of the
2017 // operations used in the match. This will be used as the location for
2018 // created operations during the rewrite that don't already have an
2019 // explicit location set.
2020 unsigned numMatchLocs = read();
2021 SmallVector<Location, 4> matchLocs;
2022 matchLocs.reserve(numMatchLocs);
2023 for (unsigned i = 0; i != numMatchLocs; ++i)
2024 matchLocs.push_back(read<Operation *>()->getLoc());
2025 Location matchLoc = rewriter.getFusedLoc(matchLocs);
2026
2027 LDBG() << " * Benefit: " << benefit.getBenefit();
2028 LDBG() << " * Location: " << matchLoc;
2029 matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2030 PDLByteCode::MatchResult &match = matches.back();
2031
2032 // Record all of the inputs to the match. If any of the inputs are ranges, we
2033 // will also need to remap the range pointer to memory stored in the match
2034 // state.
2035 unsigned numInputs = read();
2036 match.values.reserve(numInputs);
2037 match.typeRangeValues.reserve(numInputs);
2038 match.valueRangeValues.reserve(numInputs);
2039 for (unsigned i = 0; i < numInputs; ++i) {
2040 switch (read<PDLValue::Kind>()) {
2041 case PDLValue::Kind::TypeRange:
2042 match.typeRangeValues.push_back(*read<TypeRange *>());
2043 match.values.push_back(&match.typeRangeValues.back());
2044 break;
2045 case PDLValue::Kind::ValueRange:
2046 match.valueRangeValues.push_back(*read<ValueRange *>());
2047 match.values.push_back(&match.valueRangeValues.back());
2048 break;
2049 default:
2050 match.values.push_back(read<const void *>());
2051 break;
2052 }
2053 }
2054 curCodeIt = dest;
2055}
2056
2057void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2058 LDBG() << "Executing ReplaceOp:";
2059 Operation *op = read<Operation *>();
2060 SmallVector<Value, 16> args;
2061 readList(args);
2062
2063 LDBG() << " * Operation: " << *op
2064 << "\n * Values: " << llvm::interleaved(args);
2065 rewriter.replaceOp(op, args);
2066}
2067
2068void ByteCodeExecutor::executeSwitchAttribute() {
2069 LDBG() << "Executing SwitchAttribute:";
2070 Attribute value = read<Attribute>();
2071 ArrayAttr cases = read<ArrayAttr>();
2072 handleSwitch(value, cases);
2073}
2074
2075void ByteCodeExecutor::executeSwitchOperandCount() {
2076 LDBG() << "Executing SwitchOperandCount:";
2077 Operation *op = read<Operation *>();
2078 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2079
2080 LDBG() << " * Operation: " << *op;
2081 handleSwitch(op->getNumOperands(), cases);
2082}
2083
2084void ByteCodeExecutor::executeSwitchOperationName() {
2085 LDBG() << "Executing SwitchOperationName:";
2086 OperationName value = read<Operation *>()->getName();
2087 size_t caseCount = read();
2088
2089 // The operation names are stored in-line, so to print them out for
2090 // debugging purposes we need to read the array before executing the
2091 // switch so that we can display all of the possible values.
2092 LLVM_DEBUG({
2093 const ByteCodeField *prevCodeIt = curCodeIt;
2094 LDBG() << " * Value: " << value << "\n * Cases: "
2095 << llvm::interleaved(
2096 llvm::map_range(llvm::seq<size_t>(0, caseCount), [&](size_t) {
2097 return read<OperationName>();
2098 }));
2099 curCodeIt = prevCodeIt;
2100 });
2101
2102 // Try to find the switch value within any of the cases.
2103 for (size_t i = 0; i != caseCount; ++i) {
2104 if (read<OperationName>() == value) {
2105 curCodeIt += (caseCount - i - 1);
2106 return selectJump(i + 1);
2107 }
2108 }
2109 selectJump(size_t(0));
2110}
2111
2112void ByteCodeExecutor::executeSwitchResultCount() {
2113 LDBG() << "Executing SwitchResultCount:";
2114 Operation *op = read<Operation *>();
2115 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2116
2117 LDBG() << " * Operation: " << *op;
2118 handleSwitch(op->getNumResults(), cases);
2119}
2120
2121void ByteCodeExecutor::executeSwitchType() {
2122 LDBG() << "Executing SwitchType:";
2123 Type value = read<Type>();
2124 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2125 handleSwitch(value, cases);
2126}
2127
2128void ByteCodeExecutor::executeSwitchTypes() {
2129 LDBG() << "Executing SwitchTypes:";
2130 TypeRange *value = read<TypeRange *>();
2131 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2132 if (!value) {
2133 LDBG() << "Types: <NULL>";
2134 return selectJump(size_t(0));
2135 }
2136 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2137 return value == caseValue.getAsValueRange<TypeAttr>();
2138 });
2139}
2140
2141LogicalResult
2142ByteCodeExecutor::execute(PatternRewriter &rewriter,
2143 SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2144 std::optional<Location> mainRewriteLoc) {
2145 while (true) {
2146 // Print the location of the operation being executed.
2147 LDBG() << readInline<Location>();
2148
2149 OpCode opCode = static_cast<OpCode>(read());
2150 switch (opCode) {
2151 case ApplyConstraint:
2152 executeApplyConstraint(rewriter);
2153 break;
2154 case ApplyRewrite:
2155 if (failed(executeApplyRewrite(rewriter)))
2156 return failure();
2157 break;
2158 case AreEqual:
2159 executeAreEqual();
2160 break;
2161 case AreRangesEqual:
2162 executeAreRangesEqual();
2163 break;
2164 case Branch:
2165 executeBranch();
2166 break;
2167 case CheckOperandCount:
2168 executeCheckOperandCount();
2169 break;
2170 case CheckOperationName:
2171 executeCheckOperationName();
2172 break;
2173 case CheckResultCount:
2174 executeCheckResultCount();
2175 break;
2176 case CheckTypes:
2177 executeCheckTypes();
2178 break;
2179 case Continue:
2180 executeContinue();
2181 break;
2182 case CreateConstantTypeRange:
2183 executeCreateConstantTypeRange();
2184 break;
2185 case CreateOperation:
2186 executeCreateOperation(rewriter, *mainRewriteLoc);
2187 break;
2188 case CreateDynamicTypeRange:
2189 executeDynamicCreateRange<Type>("Type");
2190 break;
2191 case CreateDynamicValueRange:
2192 executeDynamicCreateRange<Value>("Value");
2193 break;
2194 case EraseOp:
2195 executeEraseOp(rewriter);
2196 break;
2197 case ExtractOp:
2198 executeExtract<Operation *, std::vector<Operation *>,
2199 PDLValue::Kind::Operation>();
2200 break;
2201 case ExtractType:
2202 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2203 break;
2204 case ExtractValue:
2205 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2206 break;
2207 case Finalize:
2208 executeFinalize();
2209 LDBG() << "";
2210 return success();
2211 case ForEach:
2212 executeForEach();
2213 break;
2214 case GetAttribute:
2215 executeGetAttribute();
2216 break;
2217 case GetAttributeType:
2218 executeGetAttributeType();
2219 break;
2220 case GetDefiningOp:
2221 executeGetDefiningOp();
2222 break;
2223 case GetOperand0:
2224 case GetOperand1:
2225 case GetOperand2:
2226 case GetOperand3: {
2227 unsigned index = opCode - GetOperand0;
2228 LDBG() << "Executing GetOperand" << index << ":";
2229 executeGetOperand(index);
2230 break;
2231 }
2232 case GetOperandN:
2233 LDBG() << "Executing GetOperandN:";
2234 executeGetOperand(read<uint32_t>());
2235 break;
2236 case GetOperands:
2237 executeGetOperands();
2238 break;
2239 case GetResult0:
2240 case GetResult1:
2241 case GetResult2:
2242 case GetResult3: {
2243 unsigned index = opCode - GetResult0;
2244 LDBG() << "Executing GetResult" << index << ":";
2245 executeGetResult(index);
2246 break;
2247 }
2248 case GetResultN:
2249 LDBG() << "Executing GetResultN:";
2250 executeGetResult(read<uint32_t>());
2251 break;
2252 case GetResults:
2253 executeGetResults();
2254 break;
2255 case GetUsers:
2256 executeGetUsers();
2257 break;
2258 case GetValueType:
2259 executeGetValueType();
2260 break;
2261 case GetValueRangeTypes:
2262 executeGetValueRangeTypes();
2263 break;
2264 case IsNotNull:
2265 executeIsNotNull();
2266 break;
2267 case RecordMatch:
2268 assert(matches &&
2269 "expected matches to be provided when executing the matcher");
2270 executeRecordMatch(rewriter, *matches);
2271 break;
2272 case ReplaceOp:
2273 executeReplaceOp(rewriter);
2274 break;
2275 case SwitchAttribute:
2276 executeSwitchAttribute();
2277 break;
2278 case SwitchOperandCount:
2279 executeSwitchOperandCount();
2280 break;
2281 case SwitchOperationName:
2282 executeSwitchOperationName();
2283 break;
2284 case SwitchResultCount:
2285 executeSwitchResultCount();
2286 break;
2287 case SwitchType:
2288 executeSwitchType();
2289 break;
2290 case SwitchTypes:
2291 executeSwitchTypes();
2292 break;
2293 }
2294 LDBG() << "";
2295 }
2296}
2297
2298void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2299 SmallVectorImpl<MatchResult> &matches,
2300 PDLByteCodeMutableState &state) const {
2301 // The first memory slot is always the root operation.
2302 state.memory[0] = op;
2303
2304 // The matcher function always starts at code address 0.
2305 ByteCodeExecutor executor(
2306 matcherByteCode.data(), state.memory, state.opRangeMemory,
2307 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2308 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2309 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2310 constraintFunctions, rewriteFunctions);
2311 LogicalResult executeResult = executor.execute(rewriter, &matches);
2312 (void)executeResult;
2313 assert(succeeded(executeResult) && "unexpected matcher execution failure");
2314
2315 // Order the found matches by benefit.
2316 llvm::stable_sort(matches,
2317 [](const MatchResult &lhs, const MatchResult &rhs) {
2318 return lhs.benefit > rhs.benefit;
2319 });
2320}
2321
2322LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2323 const MatchResult &match,
2324 PDLByteCodeMutableState &state) const {
2325 auto *configSet = match.pattern->getConfigSet();
2326 if (configSet)
2327 configSet->notifyRewriteBegin(rewriter);
2328
2329 // The arguments of the rewrite function are stored at the start of the
2330 // memory buffer.
2331 llvm::copy(match.values, state.memory.begin());
2332
2333 ByteCodeExecutor executor(
2334 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2335 state.opRangeMemory, state.typeRangeMemory,
2336 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2337 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2338 rewriterByteCode, state.currentPatternBenefits, patterns,
2339 constraintFunctions, rewriteFunctions);
2340 LogicalResult result =
2341 executor.execute(rewriter, /*matches=*/nullptr, match.location);
2342
2343 if (configSet)
2344 configSet->notifyRewriteEnd(rewriter);
2345
2346 // If the rewrite failed, check if the pattern rewriter can recover. If it
2347 // can, we can signal to the pattern applicator to keep trying patterns. If it
2348 // doesn't, we need to bail. Bailing here should be fine, given that we have
2349 // no means to propagate such a failure to the user, and it also indicates a
2350 // bug in the user code (i.e. failable rewrites should not be used with
2351 // pattern rewriters that don't support it).
2352 if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2353 LDBG() << " and rollback is not supported - aborting";
2354 llvm::report_fatal_error(
2355 "Native PDL Rewrite failed, but the pattern "
2356 "rewriter doesn't support recovery. Failable pattern rewrites should "
2357 "not be used with pattern rewriters that do not support them.");
2358 }
2359 return result;
2360}
for(Operation *op :ops)
return success()
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition ByteCode.cpp:175
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
lhs
ArrayAttr()
static const mlir::GenInfo * generator
static void processValue(Value value, LiveMap &liveMap)
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition Attributes.h:73
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:163
BlockArgListType getArguments()
Definition Block.h:97
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition Builders.cpp:27
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition Liveness.h:110
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition Liveness.cpp:375
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Types.h:167
type_range getType() const
Type getType() const
Return the type of this value.
Definition Value.h:105
user_iterator user_begin() const
Definition Value.h:216
user_iterator user_end() const
Definition Value.h:217
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Value.h:233
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition ByteCode.h:235
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
Definition ByteCode.h:234
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition ByteCode.h:248
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
Definition ByteCode.h:247
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition ByteCode.h:251
AttrTypeReplacer.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
OpFoldResult size