19#include "llvm/ADT/IntervalMap.h"
20#include "llvm/ADT/PostOrderIterator.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/Format.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/InterleavedRange.h"
30#define DEBUG_TYPE "pdl-bytecode"
40 PDLPatternConfigSet *configSet,
41 ByteCodeAddr rewriterAddr) {
47 if (
ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
49 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
52 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
56 benefit, ctx, generatedOps);
// Fragment (the head of this function's signature is not visible): stores
// `benefit` as the current benefit of the pattern at `patternIndex` in
// `currentPatternBenefits`.
67 PatternBenefit benefit) {
68 currentPatternBenefits[patternIndex] = benefit;
// Fragment (enclosing function not visible): empties the owned type-range and
// value-range buffers. ByteCodeExecutor receives these vectors by reference
// (see its constructor), so this releases storage it may have appended.
75 allocatedTypeRangeMemory.clear();
76 allocatedValueRangeMemory.clear();
// Opcodes of the PDL bytecode. The underlying type is ByteCodeField, and
// ByteCodeWriter::append(OpCode) pushes it as a single field, so every opcode
// occupies exactly one field of the bytecode stream.
// Only a subset of the enumerators is visible in this excerpt.
84enum OpCode : ByteCodeField {
// Range-construction opcodes, emitted by the CreateTypesOp / CreateRangeOp
// generators.
106 CreateConstantTypeRange,
110 CreateDynamicTypeRange,
112 CreateDynamicValueRange,
// NOTE(review): tail of a separate declaration whose name is not visible here;
// it is initialized with the largest ByteCodeField value, which suggests a
// sentinel/marker - confirm its name and use.
176 std::numeric_limits<ByteCodeField>::max();
// Forward declarations; both structs are defined further down in this file.
187struct ByteCodeLiveRange;
188struct ByteCodeWriter;
// Detection alias: well-formed only when `T` has a `getAsOpaquePointer()`
// member. It is combined with llvm::is_detected to constrain templates below,
// e.g. ByteCodeWriter's pointer-like `append` and `appendInline` overloads.
191template <
typename T,
typename... Args>
192using has_pointer_traits =
decltype(std::declval<T>().getAsOpaquePointer());
// Constructs a bytecode generator. The bytecode buffers, the pattern list and
// the max* counters are bound by reference; they are filled in later by
// generate() and allocateMemoryIndices(), so the caller owns the results.
197 Generator(
MLIRContext *ctx, std::vector<const void *> &uniquedData,
201 ByteCodeField &maxValueMemoryIndex,
202 ByteCodeField &maxOpRangeMemoryIndex,
203 ByteCodeField &maxTypeRangeMemoryIndex,
204 ByteCodeField &maxValueRangeMemoryIndex,
205 ByteCodeField &maxLoopLevel,
206 llvm::StringMap<PDLConstraintFunction> &constraintFns,
207 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
209 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
211 maxValueMemoryIndex(maxValueMemoryIndex),
212 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
213 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
214 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
215 maxLoopLevel(maxLoopLevel), configMap(configMap) {
// Give each registered native constraint / rewrite function a dense index in
// StringMap iteration order. PDLByteCode's constructor pushes the function
// objects into its vectors by iterating the same maps, so these indices line
// up with those vectors.
216 for (
const auto &it : llvm::enumerate(constraintFns))
217 constraintToMemIndex.try_emplace(it.value().first(), it.index());
218 for (
const auto &it : llvm::enumerate(rewriteFns))
219 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
// Generates the matcher and rewriter bytecode for the given PDL interpreter
// module.
223 void generate(ModuleOp module);
// Returns the memory slot assigned to `value`. The reference lets callers
// alias one value onto another's slot (see the CreateAttributeOp/CreateTypeOp
// generators). NOTE(review): with asserts disabled, operator[] silently
// inserts a slot for an unassigned value.
226 ByteCodeField &getMemIndex(Value value) {
227 assert(valueToMemIndex.count(value) &&
228 "expected memory index to be assigned");
229 return valueToMemIndex[value];
// Returns the range-storage slot assigned to a range-typed `value`. Same
// assert/operator[] caveat as getMemIndex(Value).
233 ByteCodeField &getRangeStorageIndex(Value value) {
234 assert(valueToRangeIndex.count(value) &&
235 "expected range index to be assigned");
236 return valueToRangeIndex[value];
// Returns the memory index of a non-Value constant (anything with
// getAsOpaquePointer()), uniqued by its opaque pointer. New constants are
// numbered after the maxValueMemoryIndex value slots, in the order they are
// first seen. NOTE(review): the guard that restricts push_back to newly
// inserted entries is on a line not visible here - confirm.
241 template <
typename T>
242 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
244 const void *opaqueVal = val.getAsOpaquePointer();
247 auto it = uniquedDataToMemIndex.try_emplace(
248 opaqueVal, maxValueMemoryIndex + uniquedData.size());
250 uniquedData.push_back(opaqueVal);
251 return it.first->second;
// Assigns memory slots and range-storage slots to the values of the matcher
// function and of each rewriter function.
257 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
258 ModuleOp rewriterModule);
// Emit bytecode for every block of a region / for a single operation.
261 void generate(Region *region, ByteCodeWriter &writer);
262 void generate(Operation *op, ByteCodeWriter &writer);
// Per-operation emitters: one overload per supported pdl_interp op, dispatched
// from generate(Operation *, ByteCodeWriter &) via TypeSwitch.
263 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
298 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
299 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
300 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
// Name -> dense index of each native rewrite function (assigned in the
// constructor).
310 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
// Name -> dense index of each native constraint function.
314 llvm::StringMap<ByteCodeField> constraintToMemIndex;
// Rewriter function name -> start offset of its code in rewriterByteCode.
318 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
// Nesting depth of the pdl_interp.foreach currently being generated.
325 ByteCodeField curLoopLevel = 0;
// Output state owned by the caller (bound by reference in the constructor).
334 std::vector<const void *> &uniquedData;
335 SmallVectorImpl<ByteCodeField> &matcherByteCode;
336 SmallVectorImpl<ByteCodeField> &rewriterByteCode;
337 SmallVectorImpl<PDLByteCodePattern> &
patterns;
// Running maxima that size the interpreter's memory (see the mutable-state
// initialization in PDLByteCode).
338 ByteCodeField &maxValueMemoryIndex;
339 ByteCodeField &maxOpRangeMemoryIndex;
340 ByteCodeField &maxTypeRangeMemoryIndex;
341 ByteCodeField &maxValueRangeMemoryIndex;
342 ByteCodeField &maxLoopLevel;
// Helper that appends fields to a bytecode buffer, resolving Values and
// constants to memory indices through the owning Generator.
349struct ByteCodeWriter {
350 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &
generator)
// A raw field or an opcode occupies exactly one field.
354 void append(ByteCodeField field) { bytecode.push_back(field); }
355 void append(OpCode opCode) { bytecode.push_back(opCode); }
// An address is wider than a field: its bytes are copied into two consecutive
// fields (the static_assert pins the 2:1 size ratio).
358 void append(ByteCodeAddr field) {
359 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
360 "unexpected ByteCode address size");
362 ByteCodeField fieldParts[2];
363 std::memcpy(fieldParts, &field,
sizeof(ByteCodeAddr));
364 bytecode.append({fieldParts[0], fieldParts[1]});
// A successor block's address is not known yet. Remember where the address
// goes and emit a zero placeholder; Generator::generate(ModuleOp) patches it
// from blockToAddr once all blocks have been emitted.
369 void append(
Block *successor) {
372 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
373 append(ByteCodeAddr(0));
// Emits each successor of the range (per-successor body not visible here).
378 void append(SuccessorRange successors) {
379 for (
Block *successor : successors)
// Emits a count followed by each value, where each value is tagged with its
// PDLValue kind. ByteCodeExecutor::readList(SmallVectorImpl<Type/Value> &)
// decodes this count/kind/entry layout.
384 void appendPDLValueList(OperandRange values) {
385 bytecode.push_back(values.size());
386 for (Value value : values)
387 appendPDLValue(value);
// Emits the value's kind tag; the payload that follows (presumably its memory
// index) is on a line not visible here.
391 void appendPDLValue(Value value) {
392 appendPDLValueKind(value);
397 void appendPDLValueKind(Value value) { appendPDLValueKind(value.
getType()); }
// Maps a PDL type to its PDLValue::Kind and emits it as one field. For
// pdl.range, a TypeType element yields TypeRange and any other element type
// yields ValueRange. No default case is visible in this excerpt.
400 void appendPDLValueKind(Type type) {
401 PDLValue::Kind kind =
403 .Case<pdl::AttributeType>(
404 [](Type) {
return PDLValue::Kind::Attribute; })
405 .Case<pdl::OperationType>(
406 [](Type) {
return PDLValue::Kind::Operation; })
407 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
408 if (isa<pdl::TypeType>(rangeTy.getElementType()))
409 return PDLValue::Kind::TypeRange;
410 return PDLValue::Kind::ValueRange;
412 .Case<pdl::TypeType>([](Type) {
return PDLValue::Kind::Type; })
413 .Case<pdl::ValueType>([](Type) {
return PDLValue::Kind::Value; });
414 bytecode.push_back(
static_cast<ByteCodeField
>(kind));
// Pointer-like operands (Values, attributes, types, ... anything with
// getAsOpaquePointer(), or raw pointers) are emitted as a single field: the
// memory index that the Generator assigned to them.
419 template <
typename T>
420 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
421 std::is_pointer<T>::value>
423 bytecode.push_back(
generator.getMemIndex(value));
// Any other iterable range is emitted as its element count followed by its
// elements (the per-element append is on a line not visible here).
427 template <
typename T,
typename IteratorT = llvm::detail::IterOfRange<T>>
428 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
430 bytecode.push_back(llvm::size(range));
431 for (
auto it : range)
// Variadic convenience overload: appends each argument in order by peeling
// off the first one and recursing on the rest.
436 template <
typename FieldTy,
typename Field2Ty,
typename... FieldTys>
437 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
439 append(field2, fields...);
// Embeds the raw opaque pointer of `value` directly into the bytecode,
// spread over sizeof(void *) / sizeof(ByteCodeField) fields. Unlike append(),
// no memory index is allocated (used for operation locations).
443 template <
typename T>
444 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
445 appendInline(T value) {
446 constexpr size_t numParts =
sizeof(
const void *) /
sizeof(ByteCodeField);
447 const void *pointer = value.getAsOpaquePointer();
448 ByteCodeField fieldParts[numParts];
449 std::memcpy(fieldParts, &pointer,
sizeof(
const void *));
450 bytecode.append(fieldParts, fieldParts + numParts);
// Destination buffer (matcher or rewriter bytecode).
457 SmallVectorImpl<ByteCodeField> &bytecode;
// Live range of a value (or of a group of values sharing a memory slot),
// expressed as a set of intervals over the operation numbering computed in
// allocateMemoryIndices. The IntervalMap's mapped `char` is unused (0 is
// always inserted); only the interval keys matter.
465struct ByteCodeLiveRange {
466 using Set = llvm::IntervalMap<uint64_t, char, 16>;
467 using Allocator = Set::Allocator;
// All IntervalMaps draw nodes from a shared allocator owned by the caller.
469 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
// Adds every interval of `rhs` to this range.
472 void unionWith(
const ByteCodeLiveRange &
rhs) {
473 for (
auto it =
rhs.liveness->begin(), e =
rhs.liveness->end(); it != e;
475 liveness->insert(it.start(), it.stop(), 0);
// True if any interval of this range intersects any interval of `rhs`.
479 bool overlaps(
const ByteCodeLiveRange &
rhs)
const {
480 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *
rhs.liveness)
// NOTE(review): held behind unique_ptr, presumably so ByteCodeLiveRange can be
// moved (it is stored in std::vector and DenseMap-like containers) - confirm.
490 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
// Range-storage slots for op/type/value ranges. During analysis these are set
// to 0 only as "this value needs range storage of this kind"; the real slot is
// chosen during allocation.
493 std::optional<unsigned> opRangeIndex;
496 std::optional<unsigned> typeRangeIndex;
499 std::optional<unsigned> valueRangeIndex;
// Generates bytecode for a PDL interpreter module, which must contain the
// matcher function and the rewriter module (checked by assert only, i.e. in
// debug builds). Rewriter functions are emitted first so that their start
// addresses (rewriterToAddr) exist when RecordMatchOp is generated for the
// matcher. Finally, forward references to matcher blocks are patched.
503void Generator::generate(ModuleOp module) {
504 auto matcherFunc =
module.lookupSymbol<pdl_interp::FuncOp>(
505 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
506 ModuleOp rewriterModule =
module.lookupSymbol<ModuleOp>(
507 pdl_interp::PDLInterpDialect::getRewriterModuleName());
508 assert(matcherFunc && rewriterModule &&
"invalid PDL Interpreter module");
// Slots must be known before any bytecode that refers to values is emitted.
512 allocateMemoryIndices(matcherFunc, rewriterModule);
// Emit each rewriter function's operations, recording where it starts.
515 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *
this);
516 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
517 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
518 for (Operation &op : rewriterFunc.getOps())
519 generate(&op, rewriterByteCodeWriter);
// Rewriters are straight-line code, so no successor placeholders may remain.
521 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
522 "unexpected branches in rewriter function");
// Emit the matcher function's body region.
525 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *
this);
526 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
// Overwrite every zero placeholder emitted for a successor block with that
// block's final address (recorded by generate(Region *, ...)).
529 for (
auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
530 ByteCodeAddr addr = blockToAddr[it.first];
531 for (
unsigned offsetToFix : it.second)
532 std::memcpy(&matcherByteCode[offsetToFix], &addr,
sizeof(ByteCodeAddr));
// Assigns memory slots to every PDL value referenced by the generated
// bytecode, plus range-storage slots to range-typed values.
//
// Rewriter functions: each function gets its own dense numbering starting at
// 0 (no reuse). The maximum over all rewriter functions is published to the
// max*MemoryIndex out-parameters.
//
// Matcher function: operations are numbered linearly, each value gets a live
// range over that numbering, and values whose live ranges do not overlap share
// a slot. Slot 0 is reserved for the root operation argument.
536void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
537 ModuleOp rewriterModule) {
540 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
// Counters restart for every rewriter function.
541 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
// Gives `val` the next value slot. Type ranges and value ranges also get the
// next slot of their respective range storage (no op-range storage here).
542 auto processRewriterValue = [&](Value val) {
543 valueToMemIndex.try_emplace(val, index++);
544 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
545 Type elementTy = rangeType.getElementType();
546 if (isa<pdl::TypeType>(elementTy))
547 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
548 else if (isa<pdl::ValueType>(elementTy))
549 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
// Function arguments first, then every op result in the body.
553 for (BlockArgument arg : rewriterFunc.getArguments())
554 processRewriterValue(arg);
555 rewriterFunc.getBody().walk([&](Operation *op) {
557 processRewriterValue(
result);
// Keep the largest requirement seen across rewriter functions.
559 if (index > maxValueMemoryIndex)
560 maxValueMemoryIndex = index;
561 if (typeRangeIndex > maxTypeRangeMemoryIndex)
562 maxTypeRangeMemoryIndex = typeRangeIndex;
563 if (valueRangeIndex > maxValueRangeMemoryIndex)
564 maxValueRangeMemoryIndex = valueRangeIndex;
// Number the matcher's operations: each op gets an index on entry
// (opToFirstIndex) and another after everything nested in its regions
// (opToLastIndex). NOTE(review): the declaration of `index` used here and the
// recursive visit of `nested` are on lines not visible in this excerpt.
580 llvm::unique_function<void(Operation *)>
walk = [&](Operation *op) {
581 opToFirstIndex.try_emplace(op, index++);
583 for (
Block &block : region.getBlocks())
584 for (Operation &nested : block)
586 opToLastIndex.try_emplace(op, index++);
// Node allocator shared by every live-range IntervalMap created below.
591 ByteCodeLiveRange::Allocator allocator;
// The root operation (first matcher argument) is pinned to slot 0.
595 BlockArgument rootOpArg = matcherFunc.getArgument(0);
596 valueToMemIndex[rootOpArg] = 0;
// Compute each value's live interval(s) from MLIR's block liveness.
599 Liveness matcherLiveness(matcherFunc);
600 matcherFunc->walk([&](
Block *block) {
601 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
602 assert(info &&
"expected liveness info for block");
// Adds to `value`'s live range an interval starting at `firstUseOrDef`. The
// root is excluded because it already owns slot 0. (The interval's end point
// is computed on lines not visible here.)
603 auto processValue = [&](Value value, Operation *firstUseOrDef) {
606 if (value == rootOpArg)
610 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
611 defRangeIt->second.liveness->insert(
612 opToFirstIndex[firstUseOrDef],
// Flag range values as needing range storage of the matching kind. The 0 is
// only a marker; the actual slot is picked during allocation below.
617 if (
auto rangeTy = dyn_cast<pdl::RangeType>(value.
getType())) {
618 Type eleType = rangeTy.getElementType();
619 if (isa<pdl::OperationType>(eleType))
620 defRangeIt->second.opRangeIndex = 0;
621 else if (isa<pdl::TypeType>(eleType))
622 defRangeIt->second.typeRangeIndex = 0;
623 else if (isa<pdl::ValueType>(eleType))
624 defRangeIt->second.valueRangeIndex = 0;
// Values live into this block (handling only partially visible here).
629 for (Value liveIn : info->
in()) {
634 if (liveIn.getParentRegion() == block->
getParent())
// Values defined by this block's operations (loop body not visible here).
645 for (Operation &op : *block)
// Greedy allocation. Each entry of `allocatedIndices` is the union of the live
// ranges of all values sharing one slot; the slot number is the entry's
// position + 1 (slot 0 is the root).
651 std::vector<ByteCodeLiveRange> allocatedIndices;
655 ByteCodeField numIndices = 1;
658 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
660 for (
auto &defIt : valueDefRanges) {
661 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
662 ByteCodeLiveRange &defRange = defIt.second;
// Try to reuse an existing slot whose combined live range does not overlap.
665 for (
const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
666 ByteCodeLiveRange &existingRange = existingIndexIt.value();
667 if (!defRange.overlaps(existingRange)) {
668 existingRange.unionWith(defRange);
669 memIndex = existingIndexIt.index() + 1;
// A range value also needs range storage: reuse the slot's storage of the
// same kind if it already has one, otherwise allocate a fresh one.
671 if (defRange.opRangeIndex) {
672 if (!existingRange.opRangeIndex)
673 existingRange.opRangeIndex = numOpRanges++;
674 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
675 }
else if (defRange.typeRangeIndex) {
676 if (!existingRange.typeRangeIndex)
677 existingRange.typeRangeIndex = numTypeRanges++;
678 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
679 }
else if (defRange.valueRangeIndex) {
680 if (!existingRange.valueRangeIndex)
681 existingRange.valueRangeIndex = numValueRanges++;
682 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
// No compatible slot found (the control flow selecting this path is not fully
// visible here): open a new slot, with fresh range storage if needed.
690 allocatedIndices.emplace_back(allocator);
691 ByteCodeLiveRange &newRange = allocatedIndices.back();
692 newRange.unionWith(defRange);
695 if (defRange.opRangeIndex) {
696 newRange.opRangeIndex = numOpRanges;
697 valueToRangeIndex[defIt.first] = numOpRanges++;
698 }
else if (defRange.typeRangeIndex) {
699 newRange.typeRangeIndex = numTypeRanges;
700 valueToRangeIndex[defIt.first] = numTypeRanges++;
701 }
else if (defRange.valueRangeIndex) {
702 newRange.valueRangeIndex = numValueRanges;
703 valueToRangeIndex[defIt.first] = numValueRanges++;
// The new slot's number is 1-based (size after emplace_back).
706 memIndex = allocatedIndices.size();
712 LDBG() <<
"Allocated " << allocatedIndices.size() <<
" indices "
713 <<
"(down from initial " << valueDefRanges.size() <<
").";
// Slot numbers are stored as ByteCodeField, so the slot count must fit.
714 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
715 "Ran out of memory for allocated indices");
// Merge the matcher's requirements into the shared maxima.
// NOTE(review): `numIndices` starts at 1 and its increment is not visible in
// this excerpt - confirm it is bumped once per newly opened slot.
718 if (numIndices > maxValueMemoryIndex)
719 maxValueMemoryIndex = numIndices;
720 if (numOpRanges > maxOpRangeMemoryIndex)
721 maxOpRangeMemoryIndex = numOpRanges;
722 if (numTypeRanges > maxTypeRangeMemoryIndex)
723 maxTypeRangeMemoryIndex = numTypeRanges;
724 if (numValueRanges > maxValueRangeMemoryIndex)
725 maxValueRangeMemoryIndex = numValueRanges;
// Emits every block of `region` in reverse post-order, recording each block's
// start address for successor patching.
// NOTE(review): the address is taken from matcherByteCode regardless of which
// buffer `writer` targets; this assumes regions are only generated for the
// matcher - confirm.
728void Generator::generate(Region *region, ByteCodeWriter &writer) {
729 llvm::ReversePostOrderTraversal<Region *> rpot(region);
730 for (
Block *block : rpot) {
732 blockToAddr.try_emplace(block, matcherByteCode.size());
733 for (Operation &op : *block)
734 generate(&op, writer);
// Emits bytecode for one pdl_interp operation by dispatching to the matching
// typed overload. Any op not listed is a programming error (unreachable).
738void Generator::generate(Operation *op, ByteCodeWriter &writer) {
739 LDBG() <<
"Generating bytecode for operation: " << op->
getName();
// The op's location is embedded inline ahead of its code, except for
// CreateAttributeOp/CreateTypeOp, whose generators emit no bytecode at all.
// NOTE(review): any preprocessor guard around this (lines hidden here) should
// match ByteCodeExecutor::getPrevCodeIt's pointer-sized skip.
743 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
744 writer.appendInline(op->
getLoc());
747 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
748 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
749 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
750 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
751 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
752 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
753 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
754 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
755 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
756 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
757 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
758 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
759 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
760 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
761 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
762 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
763 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
764 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
765 pdl_interp::SwitchResultCountOp>(
766 [&](
auto interpOp) { this->generate(interpOp, writer); })
767 .DefaultUnreachable(
"unknown `pdl_interp` operation");
// ApplyConstraint layout:
//   opcode, constraint index, argument PDL value list, negation flag,
//   result count, then per result its kind (+ range storage slot for ranges),
//   then the successors.
770void Generator::generate(pdl_interp::ApplyConstraintOp op,
771 ByteCodeWriter &writer) {
775 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
776 writer.appendPDLValueList(op.getArgs());
777 writer.append(ByteCodeField(op.getIsNegated()));
778 ResultRange results = op.getResults();
779 writer.append(ByteCodeField(results.size()));
780 for (Value
result : results) {
// The expected kind of each result is recorded so the executor can check what
// the native constraint produced.
784 writer.appendPDLValueKind(
result);
// Range results additionally carry their range storage slot. (Any further
// per-result field is on a line not visible here.)
787 if (isa<pdl::RangeType>(
result.getType()))
788 writer.append(getRangeStorageIndex(
result));
791 writer.append(op.getSuccessors());
// ApplyRewrite layout:
//   opcode, native rewrite index, argument PDL value list, result count, then
//   per result its kind (+ range storage slot for ranges).
793void Generator::generate(pdl_interp::ApplyRewriteOp op,
794 ByteCodeWriter &writer) {
795 assert(externalRewriterToMemIndex.count(op.getName()) &&
796 "expected index for rewrite function");
797 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
798 writer.appendPDLValueList(op.getArgs());
800 ResultRange results = op.getResults();
801 writer.append(ByteCodeField(results.size()));
802 for (Value
result : results) {
// Expected result kind, so the executor can verify the native rewrite output.
805 writer.appendPDLValueKind(
result);
// Range results additionally carry their range storage slot.
808 if (isa<pdl::RangeType>(
result.getType()))
809 writer.append(getRangeStorageIndex(
result));
// Equality check between two values. Ranges use AreRangesEqual and carry the
// range kind so the executor knows how to compare them; everything else uses
// the plain AreEqual opcode. Both forms end with the successors.
813void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
814 Value
lhs = op.getLhs();
815 if (isa<pdl::RangeType>(
lhs.getType())) {
816 writer.append(OpCode::AreRangesEqual);
817 writer.appendPDLValueKind(
lhs);
818 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
822 writer.append(OpCode::AreEqual,
lhs, op.getRhs(), op.getSuccessors());
// BranchOp: Branch opcode followed by the op's successor address(es).
824void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
825 writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
// CheckAttributeOp reuses AreEqual: the attribute value is compared with the
// constant, which is given a uniqued memory index (remaining operands are on
// lines not visible here).
827void Generator::generate(pdl_interp::CheckAttributeOp op,
828 ByteCodeWriter &writer) {
829 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
// CheckOperandCount: op slot, expected count, and the "at least" flag as one
// field (trailing operands not visible here).
832void Generator::generate(pdl_interp::CheckOperandCountOp op,
833 ByteCodeWriter &writer) {
834 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
835 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
// CheckOperationName: op slot, uniqued OperationName, successors.
838void Generator::generate(pdl_interp::CheckOperationNameOp op,
839 ByteCodeWriter &writer) {
840 writer.append(OpCode::CheckOperationName, op.getInputOp(),
841 OperationName(op.getName(), ctx), op.getSuccessors());
// CheckResultCount: same layout as CheckOperandCount.
843void Generator::generate(pdl_interp::CheckResultCountOp op,
844 ByteCodeWriter &writer) {
845 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
846 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
// CheckTypeOp reuses AreEqual, comparing the value with a constant type.
849void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
850 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
// CheckTypes: compares a type range with a constant list of types.
853void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
854 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
// Continue: emits the 0-based level of the innermost enclosing foreach.
857void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
858 assert(curLoopLevel > 0 &&
"encountered pdl_interp.continue at top level");
859 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
// CreateAttributeOp emits no bytecode: the result simply aliases the memory
// slot of the constant attribute.
861void Generator::generate(pdl_interp::CreateAttributeOp op,
862 ByteCodeWriter &writer) {
864 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
// CreateOperation layout:
//   opcode, result op slot, OperationName, operand PDL value list,
//   attribute count, (attribute name, attribute value) pairs, result types.
866void Generator::generate(pdl_interp::CreateOperationOp op,
867 ByteCodeWriter &writer) {
868 writer.append(OpCode::CreateOperation, op.getResultOp(),
869 OperationName(op.getName(), ctx));
870 writer.appendPDLValueList(op.getInputOperands());
873 OperandRange attributes = op.getInputAttributes();
874 writer.append(
static_cast<ByteCodeField
>(attributes.size()));
875 for (
auto it : llvm::zip(op.getInputAttributeNames(), attributes))
876 writer.append(std::get<0>(it), std::get<1>(it));
// Result types: the inferred-types branch body is not visible here; otherwise
// the explicit result types are emitted as a PDL value list.
880 if (op.getInferredResultTypes())
883 writer.appendPDLValueList(op.getInputResultTypes());
// CreateRangeOp: opcode chosen by the range's element type, then the result
// slot, its range storage slot, and all operands as a PDL value list.
885void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
889 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
890 .Case([&](pdl::ValueType) {
891 writer.append(OpCode::CreateDynamicValueRange);
894 writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
895 writer.appendPDLValueList(op->getOperands());
// CreateTypeOp emits no bytecode: the result aliases the constant type's slot.
897void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
899 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
// CreateTypesOp: builds a constant type range in the result's range storage.
901void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
902 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
903 getRangeStorageIndex(op.getResult()), op.getValue());
// EraseOp: opcode + slot of the op to erase.
905void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
906 writer.append(OpCode::EraseOp, op.getInputOp());
// ExtractOp: opcode selected by element type (op / value / type), then the
// range, the index and the result slot.
908void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
911 .Case([](pdl::OperationType) {
return OpCode::ExtractOp; })
912 .Case([](pdl::ValueType) {
return OpCode::ExtractValue; })
913 .Case([](pdl::TypeType) {
return OpCode::ExtractType; })
914 .DefaultUnreachable(
"unsupported element type");
915 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
917void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
918 writer.append(OpCode::Finalize);
// ForEach layout:
//   opcode, range storage slot of the iterated values, loop variable slot,
//   loop variable kind, loop level, exit successor; then the body region.
// NOTE(review): the curLoopLevel increment/decrement around the body is on
// lines not visible here.
920void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
921 BlockArgument arg = op.getLoopVariable();
922 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
923 writer.appendPDLValueKind(arg.
getType());
924 writer.append(curLoopLevel, op.getSuccessor());
// Track the deepest nesting so the executor can size its loop-index storage.
926 if (curLoopLevel > maxLoopLevel)
927 maxLoopLevel = curLoopLevel;
928 generate(&op.getRegion(), writer);
// GetAttribute: result slot, input op slot (remaining operands not visible).
931void Generator::generate(pdl_interp::GetAttributeOp op,
932 ByteCodeWriter &writer) {
933 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
936void Generator::generate(pdl_interp::GetAttributeTypeOp op,
937 ByteCodeWriter &writer) {
938 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
// GetDefiningOp: result op slot, then the queried value tagged with its kind.
940void Generator::generate(pdl_interp::GetDefiningOpOp op,
941 ByteCodeWriter &writer) {
942 writer.append(OpCode::GetDefiningOp, op.getInputOp());
943 writer.appendPDLValue(op.getValue());
// GetOperand: small indices use a specialized opcode (GetOperand0 + index; the
// selecting condition is not visible here), others use GetOperandN + index.
// Both are followed by the input op slot and the result slot.
945void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
946 uint32_t index = op.getIndex();
948 writer.append(
static_cast<OpCode
>(OpCode::GetOperand0 + index));
950 writer.append(OpCode::GetOperandN, index);
951 writer.append(op.getInputOp(), op.getValue());
// GetOperands: the operand-group index, or uint32_t max when no index is given.
// Range results are followed by their range storage slot; otherwise
// (presumably, the else line is not visible) the max ByteCodeField value marks
// "no range storage".
953void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
954 Value
result = op.getValue();
955 std::optional<uint32_t> index = op.getIndex();
956 writer.append(OpCode::GetOperands,
957 index.value_or(std::numeric_limits<uint32_t>::max()),
959 if (isa<pdl::RangeType>(
result.getType()))
960 writer.append(getRangeStorageIndex(
result));
962 writer.append(std::numeric_limits<ByteCodeField>::max());
// GetResult: mirrors GetOperand (specialized opcode for small indices, else
// GetResultN + index), followed by the input op slot and the result slot.
965void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
966 uint32_t index = op.getIndex();
968 writer.append(
static_cast<OpCode
>(OpCode::GetResult0 + index));
970 writer.append(OpCode::GetResultN, index);
971 writer.append(op.getInputOp(), op.getValue());
// GetResults: mirrors GetOperands, including the uint32_t max "no index" value
// and the max-ByteCodeField "no range storage" marker.
973void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
974 Value
result = op.getValue();
975 std::optional<uint32_t> index = op.getIndex();
976 writer.append(OpCode::GetResults,
977 index.value_or(std::numeric_limits<uint32_t>::max()),
979 if (isa<pdl::RangeType>(
result.getType()))
980 writer.append(getRangeStorageIndex(
result));
982 writer.append(std::numeric_limits<ByteCodeField>::max());
// GetUsers: result op-range slot and its range storage slot, then the queried
// value tagged with its kind.
985void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
986 Value operations = op.getOperations();
987 ByteCodeField rangeIndex = getRangeStorageIndex(operations);
988 writer.append(OpCode::GetUsers, operations, rangeIndex);
989 writer.appendPDLValue(op.getValue());
// GetValueType: a range of values yields GetValueRangeTypes (with the result's
// range storage slot); a single value yields GetValueType.
991void Generator::generate(pdl_interp::GetValueTypeOp op,
992 ByteCodeWriter &writer) {
993 if (isa<pdl::RangeType>(op.getType())) {
994 Value
result = op.getResult();
995 writer.append(OpCode::GetValueRangeTypes,
result,
996 getRangeStorageIndex(
result), op.getValue());
998 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
1001void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1002 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
// RecordMatch: registers a new PDLByteCodePattern whose rewriter starts at the
// address recorded while generating the rewriter module (rewriters are emitted
// before the matcher). Then emits the pattern index, successors, matched ops
// and the rewriter inputs as a PDL value list.
1004void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1005 ByteCodeField patternIndex =
patterns.size();
1006 patterns.emplace_back(PDLByteCodePattern::create(
1007 op, configMap.lookup(op),
1008 rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1009 writer.append(OpCode::RecordMatch, patternIndex,
1010 SuccessorRange(op.getOperation()), op.getMatchedOps());
1011 writer.appendPDLValueList(op.getInputs());
// ReplaceOp: op slot followed by the replacement values as a PDL value list.
1013void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1014 writer.append(OpCode::ReplaceOp, op.getInputOp());
1015 writer.appendPDLValueList(op.getReplValues());
// Switch* ops: the value being switched on, the case values, then one
// successor per case (plus any default, as given by getSuccessors()).
1017void Generator::generate(pdl_interp::SwitchAttributeOp op,
1018 ByteCodeWriter &writer) {
1019 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1020 op.getCaseValuesAttr(), op.getSuccessors());
1022void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1023 ByteCodeWriter &writer) {
1024 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1025 op.getCaseValuesAttr(), op.getSuccessors());
// Operation-name cases are converted to uniqued OperationName values.
1027void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1028 ByteCodeWriter &writer) {
1029 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1030 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1032 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1033 op.getSuccessors());
1035void Generator::generate(pdl_interp::SwitchResultCountOp op,
1036 ByteCodeWriter &writer) {
1037 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1038 op.getCaseValuesAttr(), op.getSuccessors());
1040void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1041 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1042 op.getSuccessors());
1044void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1045 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1046 op.getSuccessors());
// Compiles a PDL interpreter module into bytecode. A Generator is bound to
// this object's bytecode buffers, pattern list and max* counters (the
// generate() call itself is on a line not visible here).
1053PDLByteCode::PDLByteCode(
1054 ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1056 llvm::StringMap<PDLConstraintFunction> constraintFns,
1057 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1058 : configs(std::move(configs)) {
1059 Generator
generator(module.getContext(), uniquedData, matcherByteCode,
1060 rewriterByteCode,
patterns, maxValueMemoryIndex,
1061 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1062 maxLoopLevel, constraintFns, rewriteFns, configMap);
// Move the native functions out in StringMap iteration order. This is the
// same order the Generator constructor used with llvm::enumerate to assign
// their indices, so index i refers to element i of these vectors.
1066 for (
auto &it : constraintFns)
1067 constraintFunctions.push_back(std::move(it.second));
1068 for (
auto &it : rewriteFns)
1069 rewriteFunctions.push_back(std::move(it.second));
// Fragment (enclosing function header not visible): sizes the mutable
// interpreter state from the maxima computed during generation.
// NOTE(review): `memory` covers only the maxValueMemoryIndex value slots;
// uniqued constants are indexed from maxValueMemoryIndex upward and are
// presumably served from the separate uniqued-data array - confirm.
1075 state.memory.resize(maxValueMemoryIndex,
nullptr);
1076 state.opRangeMemory.resize(maxOpRangeCount);
1077 state.typeRangeMemory.resize(maxTypeRangeCount,
TypeRange());
1078 state.valueRangeMemory.resize(maxValueRangeCount,
ValueRange());
1079 state.loopIndex.resize(maxLoopLevel, 0);
// Seed each pattern's current benefit with its static benefit.
1080 state.currentPatternBenefits.reserve(
patterns.size());
1081 for (
const PDLByteCodePattern &pattern :
patterns)
1082 state.currentPatternBenefits.push_back(pattern.getBenefit());
// PDLResultList subclass that exposes the results, and the type/value range
// storage, that PDLResultList holds. Used when processing native function
// results (see ByteCodeExecutor::processNativeFunResults).
1093class ByteCodeRewriteResultList :
public PDLResultList {
1095 ByteCodeRewriteResultList(
unsigned maxNumResults)
1096 : PDLResultList(maxNumResults) {}
// Results pushed by the native function.
1099 MutableArrayRef<PDLValue> getResults() {
return results; }
// Owned storage backing type ranges returned by the native function.
1102 MutableArrayRef<std::vector<Type>> getAllocatedTypeRanges() {
1103 return allocatedTypeRanges;
// Owned storage backing value ranges returned by the native function.
1107 MutableArrayRef<std::vector<Value>> getAllocatedValueRanges() {
1108 return allocatedValueRanges;
// Interpreter for the bytecode produced by Generator. The executor does not
// own its state: memory, range storage and loop indices are passed as mutable
// views; the code, uniqued constants, patterns and native functions as
// read-only views.
1113class ByteCodeExecutor {
1115 ByteCodeExecutor(
const ByteCodeField *curCodeIt,
1116 MutableArrayRef<const void *> memory,
1117 MutableArrayRef<std::vector<Operation *>> opRangeMemory,
1118 MutableArrayRef<TypeRange> typeRangeMemory,
1119 std::vector<std::vector<Type>> &allocatedTypeRangeMemory,
1120 MutableArrayRef<ValueRange> valueRangeMemory,
1121 std::vector<std::vector<Value>> &allocatedValueRangeMemory,
1122 MutableArrayRef<unsigned> loopIndex,
1123 ArrayRef<const void *> uniquedMemory,
1124 ArrayRef<ByteCodeField> code,
1125 ArrayRef<PatternBenefit> currentPatternBenefits,
1126 ArrayRef<PDLByteCodePattern>
patterns,
1127 ArrayRef<PDLConstraintFunction> constraintFunctions,
1128 ArrayRef<PDLRewriteFunction> rewriteFunctions)
1129 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1130 typeRangeMemory(typeRangeMemory),
1131 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1132 valueRangeMemory(valueRangeMemory),
1133 allocatedValueRangeMemory(allocatedValueRangeMemory),
1134 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1136 constraintFunctions(constraintFunctions),
1137 rewriteFunctions(rewriteFunctions) {}
// Runs bytecode starting at the current position. When `matches` is provided,
// matches found by RecordMatch are appended to it. `mainRewriteLoc` is
// forwarded to operation creation (return type on a line not visible here).
1143 execute(PatternRewriter &rewriter,
1144 SmallVectorImpl<PDLByteCode::MatchResult> *matches =
nullptr,
1145 std::optional<Location> mainRewriteLoc = {});
// One handler per opcode; each decodes the layout written by the matching
// Generator::generate overload.
1149 void executeApplyConstraint(PatternRewriter &rewriter);
1150 LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1151 void executeAreEqual();
1152 void executeAreRangesEqual();
1153 void executeBranch();
1154 void executeCheckOperandCount();
1155 void executeCheckOperationName();
1156 void executeCheckResultCount();
1157 void executeCheckTypes();
1158 void executeContinue();
1159 void executeCreateConstantTypeRange();
1160 void executeCreateOperation(PatternRewriter &rewriter,
1161 Location mainRewriteLoc);
1162 template <
typename T>
1163 void executeDynamicCreateRange(StringRef type);
1164 void executeEraseOp(PatternRewriter &rewriter);
1165 template <
typename T,
typename Range, PDLValue::Kind kind>
1166 void executeExtract();
1167 void executeFinalize();
1168 void executeForEach();
1169 void executeGetAttribute();
1170 void executeGetAttributeType();
1171 void executeGetDefiningOp();
1172 void executeGetOperand(
unsigned index);
1173 void executeGetOperands();
1174 void executeGetResult(
unsigned index);
1175 void executeGetResults();
1176 void executeGetUsers();
1177 void executeGetValueType();
1178 void executeGetValueRangeTypes();
1179 void executeIsNotNull();
1180 void executeRecordMatch(PatternRewriter &rewriter,
1181 SmallVectorImpl<PDLByteCode::MatchResult> &matches);
1182 void executeReplaceOp(PatternRewriter &rewriter);
1183 void executeSwitchAttribute();
1184 void executeSwitchOperandCount();
1185 void executeSwitchOperationName();
1186 void executeSwitchResultCount();
1187 void executeSwitchType();
1188 void executeSwitchTypes();
// Stores native-function results into executor memory and updates
// `rewriteResult` (behavior inferred from the signature; body not visible).
1189 void processNativeFunResults(ByteCodeRewriteResultList &results,
1190 unsigned numResults,
1191 LogicalResult &rewriteResult);
1194 void pushCodeIt(
const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1198 assert(!resumeCodeIt.empty() &&
"attempt to pop code off empty stack");
1199 curCodeIt = resumeCodeIt.pop_back_val();
1203 const ByteCodeField *getPrevCodeIt()
const {
1206 return curCodeIt - 1 -
sizeof(
const void *) /
sizeof(ByteCodeField);
1210 return curCodeIt - 1;
1216 template <
typename T = ByteCodeField>
1217 T read(
size_t skipN = 0) {
1219 return readImpl<T>();
1221 ByteCodeField read(
size_t skipN = 0) {
return read<ByteCodeField>(skipN); }
1224 template <
typename ValueT,
typename T>
1225 void readList(SmallVectorImpl<T> &list) {
1227 for (
unsigned i = 0, e = read(); i != e; ++i)
1228 list.push_back(read<ValueT>());
1233 void readList(SmallVectorImpl<Type> &list) {
1234 for (
unsigned i = 0, e = read(); i != e; ++i) {
1235 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1236 list.push_back(read<Type>());
1238 TypeRange *values = read<TypeRange *>();
1239 list.append(values->begin(), values->end());
1243 void readList(SmallVectorImpl<Value> &list) {
1244 for (
unsigned i = 0, e = read(); i != e; ++i) {
1245 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1246 list.push_back(read<Value>());
1249 list.append(values->begin(), values->end());
1255 template <
typename T>
1256 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1258 const void *pointer;
1259 std::memcpy(&pointer, curCodeIt,
sizeof(
const void *));
1260 curCodeIt +=
sizeof(
const void *) /
sizeof(ByteCodeField);
1261 return T::getFromOpaquePointer(pointer);
1264 void skip(
size_t skipN) { curCodeIt += skipN; }
1267 void selectJump(
bool isTrue) { selectJump(
size_t(isTrue ? 0 : 1)); }
1269 void selectJump(
size_t destIndex) {
1270 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1274 template <
typename T,
typename RangeT,
typename Comparator = std::equal_to<T>>
1275 void handleSwitch(
const T &value, RangeT &&cases, Comparator cmp = {}) {
1276 LDBG() <<
"Switch operation:\n * Value: " << value
1277 <<
"\n * Cases: " << llvm::interleaved(cases);
1281 for (
auto it = cases.begin(), e = cases.end(); it != e; ++it)
1282 if (cmp(*it, value))
1283 return selectJump(
size_t((it - cases.begin()) + 1));
1284 selectJump(
size_t(0));
1288 void storeToMemory(
unsigned index,
const void *value) {
1289 memory[index] = value;
1293 template <
typename T>
1294 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1295 storeToMemory(
unsigned index, T value) {
1296 memory[index] = value.getAsOpaquePointer();
1301 template <
typename T>
1302 const void *readFromMemory() {
1303 size_t index = *curCodeIt++;
1308 index < memory.size())
1309 return memory[index];
1312 return uniquedMemory[index - memory.size()];
1314 template <
typename T>
1315 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1316 return reinterpret_cast<T
>(
const_cast<void *
>(readFromMemory<T>()));
1318 template <
typename T>
1319 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1322 return T(T::getFromOpaquePointer(readFromMemory<T>()));
1324 template <
typename T>
1325 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1326 switch (read<PDLValue::Kind>()) {
1327 case PDLValue::Kind::Attribute:
1328 return read<Attribute>();
1329 case PDLValue::Kind::Operation:
1330 return read<Operation *>();
1331 case PDLValue::Kind::Type:
1332 return read<Type>();
1333 case PDLValue::Kind::Value:
1334 return read<Value>();
1335 case PDLValue::Kind::TypeRange:
1336 return read<TypeRange *>();
1337 case PDLValue::Kind::ValueRange:
1338 return read<ValueRange *>();
1340 llvm_unreachable(
"unhandled PDLValue::Kind");
1342 template <
typename T>
1343 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1344 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
1345 "unexpected ByteCode address size");
1347 std::memcpy(&
result, curCodeIt,
sizeof(ByteCodeAddr));
1351 template <
typename T>
1352 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1353 return *curCodeIt++;
1355 template <
typename T>
1356 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1357 return static_cast<PDLValue::Kind
>(readImpl<ByteCodeField>());
1362 template <
typename RangeT,
typename T = llvm::detail::ValueOfRange<RangeT>>
1363 void assignRangeToMemory(RangeT &&range,
unsigned memIndex,
1364 unsigned rangeIndex) {
1366 auto assignRange = [&](
auto &allocatedRangeMemory,
auto &rangeMemory) {
1368 if (range.empty()) {
1369 rangeMemory[rangeIndex] = {};
1373 allocatedRangeMemory.emplace_back(range.begin(), range.end());
1374 rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1376 memory[memIndex] = &rangeMemory[rangeIndex];
1380 if constexpr (std::is_same_v<T, Type>) {
1381 return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1382 }
else if constexpr (std::is_same_v<T, Value>) {
1383 return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1385 llvm_unreachable(
"unhandled range type");
1390 const ByteCodeField *curCodeIt;
1393 SmallVector<const ByteCodeField *> resumeCodeIt;
1396 MutableArrayRef<const void *> memory;
1397 MutableArrayRef<std::vector<Operation *>> opRangeMemory;
1398 MutableArrayRef<TypeRange> typeRangeMemory;
1399 std::vector<std::vector<Type>> &allocatedTypeRangeMemory;
1400 MutableArrayRef<ValueRange> valueRangeMemory;
1401 std::vector<std::vector<Value>> &allocatedValueRangeMemory;
1404 MutableArrayRef<unsigned> loopIndex;
1407 ArrayRef<const void *> uniquedMemory;
1408 ArrayRef<ByteCodeField> code;
1409 ArrayRef<PatternBenefit> currentPatternBenefits;
1410 ArrayRef<PDLByteCodePattern>
patterns;
1411 ArrayRef<PDLConstraintFunction> constraintFunctions;
1412 ArrayRef<PDLRewriteFunction> rewriteFunctions;
1416void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1417 LDBG() <<
"Executing ApplyConstraint:";
1418 ByteCodeField fun_idx = read();
1419 SmallVector<PDLValue, 16> args;
1420 readList<PDLValue>(args);
1422 LDBG() <<
" * Arguments: " << llvm::interleaved(args);
1424 ByteCodeField isNegated = read();
1425 LDBG() <<
" * isNegated: " << isNegated;
1427 ByteCodeField numResults = read();
1428 const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1429 ByteCodeRewriteResultList results(numResults);
1430 LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1431 [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1432 if (succeeded(rewriteResult)) {
1433 LDBG() <<
" * Constraint succeeded, results: "
1434 << llvm::interleaved(constraintResults);
1436 LDBG() <<
" * Constraint failed";
1438 assert((
failed(rewriteResult) || constraintResults.size() == numResults) &&
1439 "native PDL rewrite function succeeded but returned "
1440 "unexpected number of results");
1441 processNativeFunResults(results, numResults, rewriteResult);
1444 selectJump(isNegated != succeeded(rewriteResult));
1447LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1448 LDBG() <<
"Executing ApplyRewrite:";
1449 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1450 SmallVector<PDLValue, 16> args;
1451 readList<PDLValue>(args);
1453 LDBG() <<
" * Arguments: " << llvm::interleaved(args);
1456 ByteCodeField numResults = read();
1457 ByteCodeRewriteResultList results(numResults);
1458 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1460 assert(results.getResults().size() == numResults &&
1461 "native PDL rewrite function returned unexpected number of results");
1463 processNativeFunResults(results, numResults, rewriteResult);
1465 if (
failed(rewriteResult)) {
1466 LDBG() <<
" - Failed";
1472void ByteCodeExecutor::processNativeFunResults(
1473 ByteCodeRewriteResultList &results,
unsigned numResults,
1474 LogicalResult &rewriteResult) {
1475 if (
failed(rewriteResult)) {
1478 for (
unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1479 const PDLValue::Kind resultKind = read<PDLValue::Kind>();
1480 if (resultKind == PDLValue::Kind::TypeRange ||
1481 resultKind == PDLValue::Kind::ValueRange) {
1491 for (
unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1492 PDLValue::Kind resultKind = read<PDLValue::Kind>();
1494 PDLValue
result = results.getResults()[resultIdx];
1495 LDBG() <<
" * Result: " <<
result;
1496 assert(
result.getKind() == resultKind &&
1497 "native PDL rewrite function returned an unexpected type of "
1501 if (std::optional<TypeRange> typeRange =
result.dyn_cast<
TypeRange>()) {
1502 unsigned rangeIndex = read();
1503 typeRangeMemory[rangeIndex] = *typeRange;
1504 memory[read()] = &typeRangeMemory[rangeIndex];
1505 }
else if (std::optional<ValueRange> valueRange =
1507 unsigned rangeIndex = read();
1508 valueRangeMemory[rangeIndex] = *valueRange;
1509 memory[read()] = &valueRangeMemory[rangeIndex];
1511 memory[read()] =
result.getAsOpaquePointer();
1516 for (
auto &it : results.getAllocatedTypeRanges())
1517 allocatedTypeRangeMemory.push_back(std::move(it));
1518 for (
auto &it : results.getAllocatedValueRanges())
1519 allocatedValueRangeMemory.push_back(std::move(it));
1522void ByteCodeExecutor::executeAreEqual() {
1523 LDBG() <<
"Executing AreEqual:";
1524 const void *
lhs = read<const void *>();
1525 const void *
rhs = read<const void *>();
1527 LDBG() <<
" * " <<
lhs <<
" == " <<
rhs;
1531void ByteCodeExecutor::executeAreRangesEqual() {
1532 LDBG() <<
"Executing AreRangesEqual:";
1533 PDLValue::Kind valueKind = read<PDLValue::Kind>();
1534 const void *
lhs = read<const void *>();
1535 const void *
rhs = read<const void *>();
1537 switch (valueKind) {
1538 case PDLValue::Kind::TypeRange: {
1541 LDBG() <<
" * " <<
lhs <<
" == " <<
rhs;
1542 selectJump(*lhsRange == *rhsRange);
1545 case PDLValue::Kind::ValueRange: {
1546 const auto *lhsRange =
reinterpret_cast<const ValueRange *
>(
lhs);
1547 const auto *rhsRange =
reinterpret_cast<const ValueRange *
>(
rhs);
1548 LDBG() <<
" * " <<
lhs <<
" == " <<
rhs;
1549 selectJump(*lhsRange == *rhsRange);
1553 llvm_unreachable(
"unexpected `AreRangesEqual` value kind");
1557void ByteCodeExecutor::executeBranch() {
1558 LDBG() <<
"Executing Branch";
1559 curCodeIt = &code[read<ByteCodeAddr>()];
1562void ByteCodeExecutor::executeCheckOperandCount() {
1563 LDBG() <<
"Executing CheckOperandCount:";
1564 Operation *op = read<Operation *>();
1565 uint32_t expectedCount = read<uint32_t>();
1566 bool compareAtLeast = read();
1569 <<
"\n * Expected: " << expectedCount
1570 <<
"\n * Comparator: " << (compareAtLeast ?
">=" :
"==");
1577void ByteCodeExecutor::executeCheckOperationName() {
1578 LDBG() <<
"Executing CheckOperationName:";
1579 Operation *op = read<Operation *>();
1580 OperationName expectedName = read<OperationName>();
1582 LDBG() <<
" * Found: \"" << op->
getName() <<
"\"\n * Expected: \""
1583 << expectedName <<
"\"";
1584 selectJump(op->
getName() == expectedName);
1587void ByteCodeExecutor::executeCheckResultCount() {
1588 LDBG() <<
"Executing CheckResultCount:";
1589 Operation *op = read<Operation *>();
1590 uint32_t expectedCount = read<uint32_t>();
1591 bool compareAtLeast = read();
1594 <<
"\n * Expected: " << expectedCount
1595 <<
"\n * Comparator: " << (compareAtLeast ?
">=" :
"==");
1602void ByteCodeExecutor::executeCheckTypes() {
1603 LDBG() <<
"Executing AreEqual:";
1605 Attribute
rhs = read<Attribute>();
1606 LDBG() <<
" * " <<
lhs <<
" == " <<
rhs;
1608 selectJump(*
lhs == cast<ArrayAttr>(
rhs).getAsValueRange<TypeAttr>());
1611void ByteCodeExecutor::executeContinue() {
1612 ByteCodeField level = read();
1613 LDBG() <<
"Executing Continue\n * Level: " << level;
1618void ByteCodeExecutor::executeCreateConstantTypeRange() {
1619 LDBG() <<
"Executing CreateConstantTypeRange:";
1620 unsigned memIndex = read();
1621 unsigned rangeIndex = read();
1622 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1624 LDBG() <<
" * Types: " << typesAttr;
1625 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1629void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1630 Location mainRewriteLoc) {
1631 LDBG() <<
"Executing CreateOperation:";
1633 unsigned memIndex = read();
1634 OperationState state(mainRewriteLoc, read<OperationName>());
1635 readList(state.operands);
1636 for (
unsigned i = 0, e = read(); i != e; ++i) {
1637 StringAttr name = read<StringAttr>();
1638 if (Attribute attr = read<Attribute>())
1639 state.addAttribute(name, attr);
1644 unsigned numResults = read();
1646 InferTypeOpInterface::Concept *inferInterface =
1647 state.name.getInterface<InferTypeOpInterface>();
1648 assert(inferInterface &&
1649 "expected operation to provide InferTypeOpInterface");
1652 if (
failed(inferInterface->inferReturnTypes(
1653 state.getContext(), state.location, state.operands,
1654 state.attributes.getDictionary(state.getContext()),
1655 state.getRawProperties(), state.regions, state.types)))
1659 for (
unsigned i = 0; i != numResults; ++i) {
1660 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1661 state.types.push_back(read<Type>());
1663 TypeRange *resultTypes = read<TypeRange *>();
1664 state.types.append(resultTypes->begin(), resultTypes->end());
1669 Operation *resultOp = rewriter.
create(state);
1670 memory[memIndex] = resultOp;
1672 LDBG() <<
" * Attributes: "
1673 << state.attributes.getDictionary(state.getContext())
1674 <<
"\n * Operands: " << llvm::interleaved(state.operands)
1675 <<
"\n * Result Types: " << llvm::interleaved(state.types)
1676 <<
"\n * Result: " << *resultOp;
1679template <
typename T>
1680void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1681 LDBG() <<
"Executing CreateDynamic" << type <<
"Range:";
1682 unsigned memIndex = read();
1683 unsigned rangeIndex = read();
1684 SmallVector<T> values;
1687 LDBG() <<
" * " << type <<
"s: " << llvm::interleaved(values);
1689 assignRangeToMemory(values, memIndex, rangeIndex);
1692void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1693 LDBG() <<
"Executing EraseOp:";
1694 Operation *op = read<Operation *>();
1696 LDBG() <<
" * Operation: " << *op;
1700template <
typename T,
typename Range, PDLValue::Kind kind>
1701void ByteCodeExecutor::executeExtract() {
1702 LDBG() <<
"Executing Extract" << kind <<
":";
1703 Range *range = read<Range *>();
1704 unsigned index = read<uint32_t>();
1705 unsigned memIndex = read();
1708 memory[memIndex] =
nullptr;
1712 T
result = index < range->
size() ? (*range)[index] : T();
1713 LDBG() <<
" * " << kind <<
"s(" << range->
size() <<
")";
1714 LDBG() <<
" * Index: " << index;
1715 LDBG() <<
" * Result: " <<
result;
1716 storeToMemory(memIndex,
result);
1719void ByteCodeExecutor::executeFinalize() { LDBG() <<
"Executing Finalize"; }
1721void ByteCodeExecutor::executeForEach() {
1722 LDBG() <<
"Executing ForEach:";
1723 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1724 unsigned rangeIndex = read();
1725 unsigned memIndex = read();
1726 const void *value =
nullptr;
1728 switch (read<PDLValue::Kind>()) {
1729 case PDLValue::Kind::Operation: {
1730 unsigned &index = loopIndex[read()];
1731 ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1732 assert(index <= array.size() &&
"iterated past the end");
1733 if (index < array.size()) {
1734 LDBG() <<
" * Result: " << array[index];
1735 value = array[index];
1739 LDBG() <<
" * Done";
1741 selectJump(
size_t(0));
1745 llvm_unreachable(
"unexpected `ForEach` value kind");
1749 memory[memIndex] = value;
1750 pushCodeIt(prevCodeIt);
1753 read<ByteCodeAddr>();
1756void ByteCodeExecutor::executeGetAttribute() {
1757 LDBG() <<
"Executing GetAttribute:";
1758 unsigned memIndex = read();
1759 Operation *op = read<Operation *>();
1760 StringAttr attrName = read<StringAttr>();
1761 Attribute attr = op->
getAttr(attrName);
1763 LDBG() <<
" * Operation: " << *op <<
"\n * Attribute: " << attrName
1764 <<
"\n * Result: " << attr;
1768void ByteCodeExecutor::executeGetAttributeType() {
1769 LDBG() <<
"Executing GetAttributeType:";
1770 unsigned memIndex = read();
1771 Attribute attr = read<Attribute>();
1773 if (
auto typedAttr = dyn_cast<TypedAttr>(attr))
1774 type = typedAttr.getType();
1776 LDBG() <<
" * Attribute: " << attr <<
"\n * Result: " << type;
1780void ByteCodeExecutor::executeGetDefiningOp() {
1781 LDBG() <<
"Executing GetDefiningOp:";
1782 unsigned memIndex = read();
1783 Operation *op =
nullptr;
1784 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1785 Value value = read<Value>();
1788 LDBG() <<
" * Value: " << value;
1791 if (values && !values->empty()) {
1792 op = values->front().getDefiningOp();
1794 LDBG() <<
" * Values: " << values;
1797 LDBG() <<
" * Result: " << op;
1798 memory[memIndex] = op;
1801void ByteCodeExecutor::executeGetOperand(
unsigned index) {
1802 Operation *op = read<Operation *>();
1803 unsigned memIndex = read();
1807 LDBG() <<
" * Operation: " << *op <<
"\n * Index: " << index
1808 <<
"\n * Result: " << operand;
1815template <
template <
typename>
class AttrSizedSegmentsT,
typename RangeT>
1818 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1822 if (
index == std::numeric_limits<uint32_t>::max()) {
1823 LDBG() <<
" * Getting all values";
1827 }
else if (op->
hasTrait<AttrSizedSegmentsT>()) {
1828 LDBG() <<
" * Extracting values from `" << attrSizedSegments <<
"`";
1831 if (!segmentAttr || segmentAttr.asArrayRef().size() <=
index)
1835 unsigned startIndex = llvm::sum_of(segments.take_front(
index));
1836 values = values.slice(startIndex, *std::next(segments.begin(),
index));
1838 LDBG() <<
" * Extracting range[" << startIndex <<
", "
1839 << *std::next(segments.begin(),
index) <<
"]";
1845 }
else if (values.size() >=
index) {
1846 LDBG() <<
" * Treating values as trailing variadic range";
1847 values = values.drop_front(
index);
1855 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1856 valueRangeMemory[rangeIndex] = values;
1857 return &valueRangeMemory[rangeIndex];
1861 return values.size() != 1 ?
nullptr : values.front().getAsOpaquePointer();
1864void ByteCodeExecutor::executeGetOperands() {
1865 LDBG() <<
"Executing GetOperands:";
1866 unsigned index = read<uint32_t>();
1867 Operation *op = read<Operation *>();
1868 ByteCodeField rangeIndex = read();
1871 op->
getOperands(), op, index, rangeIndex,
"operandSegmentSizes",
1874 LDBG() <<
" * Invalid operand range";
1878void ByteCodeExecutor::executeGetResult(
unsigned index) {
1879 Operation *op = read<Operation *>();
1880 unsigned memIndex = read();
1884 LDBG() <<
" * Operation: " << *op <<
"\n * Index: " << index
1885 <<
"\n * Result: " <<
result;
1886 memory[memIndex] =
result.getAsOpaquePointer();
1889void ByteCodeExecutor::executeGetResults() {
1890 LDBG() <<
"Executing GetResults:";
1891 unsigned index = read<uint32_t>();
1892 Operation *op = read<Operation *>();
1893 ByteCodeField rangeIndex = read();
1896 op->
getResults(), op, index, rangeIndex,
"resultSegmentSizes",
1899 LDBG() <<
" * Invalid result range";
1903void ByteCodeExecutor::executeGetUsers() {
1904 LDBG() <<
"Executing GetUsers:";
1905 unsigned memIndex = read();
1906 unsigned rangeIndex = read();
1907 std::vector<Operation *> &range = opRangeMemory[rangeIndex];
1908 memory[memIndex] = ⦥
1911 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1913 Value value = read<Value>();
1916 LDBG() <<
" * Value: " << value;
1924 LDBG() <<
" * Values (" << values->size()
1925 <<
"): " << llvm::interleaved(*values);
1927 for (Value value : *values)
1931 LDBG() <<
" * Result: " << range.
size() <<
" operations";
1934void ByteCodeExecutor::executeGetValueType() {
1935 LDBG() <<
"Executing GetValueType:";
1936 unsigned memIndex = read();
1937 Value value = read<Value>();
1938 Type type = value ? value.
getType() : Type();
1940 LDBG() <<
" * Value: " << value <<
"\n * Result: " << type;
1944void ByteCodeExecutor::executeGetValueRangeTypes() {
1945 LDBG() <<
"Executing GetValueRangeTypes:";
1946 unsigned memIndex = read();
1947 unsigned rangeIndex = read();
1950 LDBG() <<
" * Values: <NULL>";
1951 memory[memIndex] =
nullptr;
1955 LDBG() <<
" * Values (" << values->size()
1956 <<
"): " << llvm::interleaved(*values)
1957 <<
"\n * Result: " << llvm::interleaved(values->
getType());
1958 typeRangeMemory[rangeIndex] = values->
getType();
1959 memory[memIndex] = &typeRangeMemory[rangeIndex];
1962void ByteCodeExecutor::executeIsNotNull() {
1963 LDBG() <<
"Executing IsNotNull:";
1964 const void *value = read<const void *>();
1966 LDBG() <<
" * Value: " << value;
1967 selectJump(value !=
nullptr);
1970void ByteCodeExecutor::executeRecordMatch(
1971 PatternRewriter &rewriter,
1972 SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1973 LDBG() <<
"Executing RecordMatch:";
1974 unsigned patternIndex = read();
1975 PatternBenefit benefit = currentPatternBenefits[patternIndex];
1976 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1981 LDBG() <<
" * Benefit: Impossible To Match";
1990 unsigned numMatchLocs = read();
1991 SmallVector<Location, 4> matchLocs;
1992 matchLocs.reserve(numMatchLocs);
1993 for (
unsigned i = 0; i != numMatchLocs; ++i)
1994 matchLocs.push_back(read<Operation *>()->getLoc());
1995 Location matchLoc = rewriter.
getFusedLoc(matchLocs);
1997 LDBG() <<
" * Benefit: " << benefit.
getBenefit();
1998 LDBG() <<
" * Location: " << matchLoc;
1999 matches.emplace_back(matchLoc,
patterns[patternIndex], benefit);
2005 unsigned numInputs = read();
2006 match.values.reserve(numInputs);
2007 match.typeRangeValues.reserve(numInputs);
2008 match.valueRangeValues.reserve(numInputs);
2009 for (
unsigned i = 0; i < numInputs; ++i) {
2010 switch (read<PDLValue::Kind>()) {
2011 case PDLValue::Kind::TypeRange:
2012 match.typeRangeValues.push_back(*read<TypeRange *>());
2013 match.values.push_back(&match.typeRangeValues.back());
2015 case PDLValue::Kind::ValueRange:
2016 match.valueRangeValues.push_back(*read<ValueRange *>());
2017 match.values.push_back(&match.valueRangeValues.back());
2020 match.values.push_back(read<const void *>());
2027void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2028 LDBG() <<
"Executing ReplaceOp:";
2029 Operation *op = read<Operation *>();
2030 SmallVector<Value, 16> args;
2033 LDBG() <<
" * Operation: " << *op
2034 <<
"\n * Values: " << llvm::interleaved(args);
2038void ByteCodeExecutor::executeSwitchAttribute() {
2039 LDBG() <<
"Executing SwitchAttribute:";
2040 Attribute value = read<Attribute>();
2042 handleSwitch(value, cases);
2045void ByteCodeExecutor::executeSwitchOperandCount() {
2046 LDBG() <<
"Executing SwitchOperandCount:";
2047 Operation *op = read<Operation *>();
2048 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2050 LDBG() <<
" * Operation: " << *op;
2054void ByteCodeExecutor::executeSwitchOperationName() {
2055 LDBG() <<
"Executing SwitchOperationName:";
2056 OperationName value = read<Operation *>()->getName();
2057 size_t caseCount = read();
2063 const ByteCodeField *prevCodeIt = curCodeIt;
2064 LDBG() <<
" * Value: " << value <<
"\n * Cases: "
2065 << llvm::interleaved(
2066 llvm::map_range(llvm::seq<size_t>(0, caseCount), [&](
size_t) {
2067 return read<OperationName>();
2069 curCodeIt = prevCodeIt;
2073 for (
size_t i = 0; i != caseCount; ++i) {
2074 if (read<OperationName>() == value) {
2075 curCodeIt += (caseCount - i - 1);
2076 return selectJump(i + 1);
2079 selectJump(
size_t(0));
2082void ByteCodeExecutor::executeSwitchResultCount() {
2083 LDBG() <<
"Executing SwitchResultCount:";
2084 Operation *op = read<Operation *>();
2085 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2087 LDBG() <<
" * Operation: " << *op;
2091void ByteCodeExecutor::executeSwitchType() {
2092 LDBG() <<
"Executing SwitchType:";
2093 Type value = read<Type>();
2094 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2095 handleSwitch(value, cases);
2098void ByteCodeExecutor::executeSwitchTypes() {
2099 LDBG() <<
"Executing SwitchTypes:";
2101 auto cases = read<ArrayAttr>().getAsRange<
ArrayAttr>();
2103 LDBG() <<
"Types: <NULL>";
2104 return selectJump(
size_t(0));
2107 return value == caseValue.getAsValueRange<TypeAttr>();
2112ByteCodeExecutor::execute(PatternRewriter &rewriter,
2113 SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2114 std::optional<Location> mainRewriteLoc) {
2117 LDBG() << readInline<Location>();
2119 OpCode opCode =
static_cast<OpCode
>(read());
2121 case ApplyConstraint:
2122 executeApplyConstraint(rewriter);
2125 if (
failed(executeApplyRewrite(rewriter)))
2131 case AreRangesEqual:
2132 executeAreRangesEqual();
2137 case CheckOperandCount:
2138 executeCheckOperandCount();
2140 case CheckOperationName:
2141 executeCheckOperationName();
2143 case CheckResultCount:
2144 executeCheckResultCount();
2147 executeCheckTypes();
2152 case CreateConstantTypeRange:
2153 executeCreateConstantTypeRange();
2155 case CreateOperation:
2156 executeCreateOperation(rewriter, *mainRewriteLoc);
2158 case CreateDynamicTypeRange:
2159 executeDynamicCreateRange<Type>(
"Type");
2161 case CreateDynamicValueRange:
2162 executeDynamicCreateRange<Value>(
"Value");
2165 executeEraseOp(rewriter);
2168 executeExtract<Operation *, std::vector<Operation *>,
2169 PDLValue::Kind::Operation>();
2172 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2175 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2185 executeGetAttribute();
2187 case GetAttributeType:
2188 executeGetAttributeType();
2191 executeGetDefiningOp();
2197 unsigned index = opCode - GetOperand0;
2198 LDBG() <<
"Executing GetOperand" << index <<
":";
2199 executeGetOperand(index);
2203 LDBG() <<
"Executing GetOperandN:";
2204 executeGetOperand(read<uint32_t>());
2207 executeGetOperands();
2213 unsigned index = opCode - GetResult0;
2214 LDBG() <<
"Executing GetResult" << index <<
":";
2215 executeGetResult(index);
2219 LDBG() <<
"Executing GetResultN:";
2220 executeGetResult(read<uint32_t>());
2223 executeGetResults();
2229 executeGetValueType();
2231 case GetValueRangeTypes:
2232 executeGetValueRangeTypes();
2239 "expected matches to be provided when executing the matcher");
2240 executeRecordMatch(rewriter, *matches);
2243 executeReplaceOp(rewriter);
2245 case SwitchAttribute:
2246 executeSwitchAttribute();
2248 case SwitchOperandCount:
2249 executeSwitchOperandCount();
2251 case SwitchOperationName:
2252 executeSwitchOperationName();
2254 case SwitchResultCount:
2255 executeSwitchResultCount();
2258 executeSwitchType();
2261 executeSwitchTypes();
2269 SmallVectorImpl<MatchResult> &matches,
2272 state.memory[0] = op;
2275 ByteCodeExecutor executor(
2276 matcherByteCode.data(), state.memory, state.opRangeMemory,
2277 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2278 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2279 uniquedData, matcherByteCode, state.currentPatternBenefits,
patterns,
2280 constraintFunctions, rewriteFunctions);
2281 LogicalResult executeResult = executor.execute(rewriter, &matches);
2282 (void)executeResult;
2283 assert(succeeded(executeResult) &&
"unexpected matcher execution failure");
2286 llvm::stable_sort(matches,
2288 return lhs.benefit >
rhs.benefit;
2293 const MatchResult &match,
2295 auto *configSet =
match.pattern->getConfigSet();
2297 configSet->notifyRewriteBegin(rewriter);
2301 llvm::copy(
match.values, state.memory.begin());
2303 ByteCodeExecutor executor(
2304 &rewriterByteCode[
match.pattern->getRewriterAddr()], state.memory,
2305 state.opRangeMemory, state.typeRangeMemory,
2306 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2307 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2308 rewriterByteCode, state.currentPatternBenefits,
patterns,
2309 constraintFunctions, rewriteFunctions);
2311 executor.execute(rewriter,
nullptr,
match.location);
2314 configSet->notifyRewriteEnd(rewriter);
2323 LDBG() <<
" and rollback is not supported - aborting";
2324 llvm::report_fatal_error(
2325 "Native PDL Rewrite failed, but the pattern "
2326 "rewriter doesn't support recovery. Failable pattern rewrites should "
2327 "not be used with pattern rewriters that do not support them.");
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
static const mlir::GenInfo * generator
static void processValue(Value value, LiveMap &liveMap)
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
type_range getType() const
Type getType() const
Return the type of this value.
user_iterator user_begin() const
user_iterator user_end() const
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap