19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
28 #define DEBUG_TYPE "pdl-bytecode"
38 PDLPatternConfigSet *configSet,
39 ByteCodeAddr rewriterAddr) {
45 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
47 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
50 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
54 benefit, ctx, generatedOps);
66 currentPatternBenefits[patternIndex] = benefit;
73 allocatedTypeRangeMemory.clear();
74 allocatedValueRangeMemory.clear();
82 enum OpCode : ByteCodeField {
104 CreateConstantTypeRange,
108 CreateDynamicTypeRange,
110 CreateDynamicValueRange,
185 struct ByteCodeLiveRange;
186 struct ByteCodeWriter;
// Detection alias for pointer-like handle types: `has_pointer_traits<T>` is
// well-formed (and names the return type of `getAsOpaquePointer()`) only when
// `T` provides a `getAsOpaquePointer()` member. It is intended to be queried
// through `llvm::is_detected`, which turns the well-formedness check into a
// boolean used to enable or disable overloads. The trailing `Args` pack is not
// referenced by the alias body.
189 template <
typename T,
typename... Args>
190 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
195 Generator(
MLIRContext *ctx, std::vector<const void *> &uniquedData,
199 ByteCodeField &maxValueMemoryIndex,
200 ByteCodeField &maxOpRangeMemoryIndex,
201 ByteCodeField &maxTypeRangeMemoryIndex,
202 ByteCodeField &maxValueRangeMemoryIndex,
203 ByteCodeField &maxLoopLevel,
204 llvm::StringMap<PDLConstraintFunction> &constraintFns,
205 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
207 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
209 maxValueMemoryIndex(maxValueMemoryIndex),
210 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
211 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
212 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
213 maxLoopLevel(maxLoopLevel), configMap(configMap) {
215 constraintToMemIndex.try_emplace(it.value().first(), it.index());
217 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
221 void generate(ModuleOp module);
224 ByteCodeField &getMemIndex(
Value value) {
225 assert(valueToMemIndex.count(value) &&
226 "expected memory index to be assigned");
227 return valueToMemIndex[value];
231 ByteCodeField &getRangeStorageIndex(
Value value) {
232 assert(valueToRangeIndex.count(value) &&
233 "expected range index to be assigned");
234 return valueToRangeIndex[value];
239 template <
typename T>
240 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
242 const void *opaqueVal = val.getAsOpaquePointer();
245 auto it = uniquedDataToMemIndex.try_emplace(
246 opaqueVal, maxValueMemoryIndex + uniquedData.size());
248 uniquedData.push_back(opaqueVal);
249 return it.first->second;
255 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
256 ModuleOp rewriterModule);
259 void generate(
Region *region, ByteCodeWriter &writer);
260 void generate(
Operation *op, ByteCodeWriter &writer);
261 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
262 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
263 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
298 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
308 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
312 llvm::StringMap<ByteCodeField> constraintToMemIndex;
316 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
323 ByteCodeField curLoopLevel = 0;
332 std::vector<const void *> &uniquedData;
336 ByteCodeField &maxValueMemoryIndex;
337 ByteCodeField &maxOpRangeMemoryIndex;
338 ByteCodeField &maxTypeRangeMemoryIndex;
339 ByteCodeField &maxValueRangeMemoryIndex;
340 ByteCodeField &maxLoopLevel;
347 struct ByteCodeWriter {
352 void append(ByteCodeField field) { bytecode.push_back(field); }
353 void append(OpCode opCode) { bytecode.push_back(opCode); }
356 void append(ByteCodeAddr field) {
357 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
358 "unexpected ByteCode address size");
360 ByteCodeField fieldParts[2];
361 std::memcpy(fieldParts, &field,
sizeof(ByteCodeAddr));
362 bytecode.append({fieldParts[0], fieldParts[1]});
367 void append(
Block *successor) {
370 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
371 append(ByteCodeAddr(0));
377 for (
Block *successor : successors)
383 bytecode.push_back(values.size());
384 for (
Value value : values)
385 appendPDLValue(value);
389 void appendPDLValue(
Value value) {
390 appendPDLValueKind(value);
395 void appendPDLValueKind(
Value value) { appendPDLValueKind(value.
getType()); }
398 void appendPDLValueKind(
Type type) {
401 .Case<pdl::AttributeType>(
402 [](
Type) {
return PDLValue::Kind::Attribute; })
403 .Case<pdl::OperationType>(
404 [](
Type) {
return PDLValue::Kind::Operation; })
405 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
406 if (isa<pdl::TypeType>(rangeTy.getElementType()))
407 return PDLValue::Kind::TypeRange;
408 return PDLValue::Kind::ValueRange;
410 .Case<pdl::TypeType>([](
Type) {
return PDLValue::Kind::Type; })
411 .Case<pdl::ValueType>([](
Type) {
return PDLValue::Kind::Value; });
412 bytecode.push_back(
static_cast<ByteCodeField
>(
kind));
417 template <
typename T>
418 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
419 std::is_pointer<T>::value>
421 bytecode.push_back(
generator.getMemIndex(value));
425 template <
typename T,
typename IteratorT = llvm::detail::IterOfRange<T>>
426 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
428 bytecode.push_back(llvm::size(range));
429 for (
auto it : range)
434 template <
typename FieldTy,
typename Field2Ty,
typename... FieldTys>
435 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
437 append(field2, fields...);
441 template <
typename T>
442 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
443 appendInline(T value) {
444 constexpr
size_t numParts =
sizeof(
const void *) /
sizeof(ByteCodeField);
445 const void *pointer = value.getAsOpaquePointer();
446 ByteCodeField fieldParts[numParts];
447 std::memcpy(fieldParts, &pointer,
sizeof(
const void *));
448 bytecode.append(fieldParts, fieldParts + numParts);
463 struct ByteCodeLiveRange {
464 using Set = llvm::IntervalMap<uint64_t, char, 16>;
465 using Allocator = Set::Allocator;
467 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
470 void unionWith(
const ByteCodeLiveRange &rhs) {
471 for (
auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
473 liveness->insert(it.start(), it.stop(), 0);
477 bool overlaps(
const ByteCodeLiveRange &rhs)
const {
478 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
488 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
491 std::optional<unsigned> opRangeIndex;
494 std::optional<unsigned> typeRangeIndex;
497 std::optional<unsigned> valueRangeIndex;
501 void Generator::generate(ModuleOp module) {
502 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
503 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
504 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
505 pdl_interp::PDLInterpDialect::getRewriterModuleName());
506 assert(matcherFunc && rewriterModule &&
"invalid PDL Interpreter module");
510 allocateMemoryIndices(matcherFunc, rewriterModule);
513 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *
this);
514 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
515 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
516 for (
Operation &op : rewriterFunc.getOps())
517 generate(&op, rewriterByteCodeWriter);
519 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
520 "unexpected branches in rewriter function");
523 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *
this);
524 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
527 for (
auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
528 ByteCodeAddr addr = blockToAddr[it.first];
529 for (
unsigned offsetToFix : it.second)
530 std::memcpy(&matcherByteCode[offsetToFix], &addr,
sizeof(ByteCodeAddr));
534 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
535 ModuleOp rewriterModule) {
538 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
539 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
540 auto processRewriterValue = [&](
Value val) {
541 valueToMemIndex.try_emplace(val, index++);
542 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
543 Type elementTy = rangeType.getElementType();
544 if (isa<pdl::TypeType>(elementTy))
545 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
546 else if (isa<pdl::ValueType>(elementTy))
547 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
552 processRewriterValue(arg);
555 processRewriterValue(result);
557 if (index > maxValueMemoryIndex)
558 maxValueMemoryIndex = index;
559 if (typeRangeIndex > maxTypeRangeMemoryIndex)
560 maxTypeRangeMemoryIndex = typeRangeIndex;
561 if (valueRangeIndex > maxValueRangeMemoryIndex)
562 maxValueRangeMemoryIndex = valueRangeIndex;
579 opToFirstIndex.try_emplace(op, index++);
581 for (
Block &block : region.getBlocks())
584 opToLastIndex.try_emplace(op, index++);
589 ByteCodeLiveRange::Allocator allocator;
594 valueToMemIndex[rootOpArg] = 0;
597 Liveness matcherLiveness(matcherFunc);
598 matcherFunc->walk([&](
Block *block) {
600 assert(info &&
"expected liveness info for block");
604 if (value == rootOpArg)
608 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
609 defRangeIt->second.liveness->insert(
610 opToFirstIndex[firstUseOrDef],
615 if (
auto rangeTy = dyn_cast<pdl::RangeType>(value.
getType())) {
616 Type eleType = rangeTy.getElementType();
617 if (isa<pdl::OperationType>(eleType))
618 defRangeIt->second.opRangeIndex = 0;
619 else if (isa<pdl::TypeType>(eleType))
620 defRangeIt->second.typeRangeIndex = 0;
621 else if (isa<pdl::ValueType>(eleType))
622 defRangeIt->second.valueRangeIndex = 0;
627 for (
Value liveIn : info->
in()) {
632 if (liveIn.getParentRegion() == block->
getParent())
649 std::vector<ByteCodeLiveRange> allocatedIndices;
653 ByteCodeField numIndices = 1;
656 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
658 for (
auto &defIt : valueDefRanges) {
659 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
660 ByteCodeLiveRange &defRange = defIt.second;
663 for (
const auto &existingIndexIt :
llvm::enumerate(allocatedIndices)) {
664 ByteCodeLiveRange &existingRange = existingIndexIt.value();
665 if (!defRange.overlaps(existingRange)) {
666 existingRange.unionWith(defRange);
667 memIndex = existingIndexIt.index() + 1;
669 if (defRange.opRangeIndex) {
670 if (!existingRange.opRangeIndex)
671 existingRange.opRangeIndex = numOpRanges++;
672 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
673 }
else if (defRange.typeRangeIndex) {
674 if (!existingRange.typeRangeIndex)
675 existingRange.typeRangeIndex = numTypeRanges++;
676 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
677 }
else if (defRange.valueRangeIndex) {
678 if (!existingRange.valueRangeIndex)
679 existingRange.valueRangeIndex = numValueRanges++;
680 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
688 allocatedIndices.emplace_back(allocator);
689 ByteCodeLiveRange &newRange = allocatedIndices.back();
690 newRange.unionWith(defRange);
693 if (defRange.opRangeIndex) {
694 newRange.opRangeIndex = numOpRanges;
695 valueToRangeIndex[defIt.first] = numOpRanges++;
696 }
else if (defRange.typeRangeIndex) {
697 newRange.typeRangeIndex = numTypeRanges;
698 valueToRangeIndex[defIt.first] = numTypeRanges++;
699 }
else if (defRange.valueRangeIndex) {
700 newRange.valueRangeIndex = numValueRanges;
701 valueToRangeIndex[defIt.first] = numValueRanges++;
704 memIndex = allocatedIndices.size();
711 llvm::dbgs() <<
"Allocated " << allocatedIndices.size() <<
" indices "
712 <<
"(down from initial " << valueDefRanges.size() <<
").\n";
715 "Ran out of memory for allocated indices");
718 if (numIndices > maxValueMemoryIndex)
719 maxValueMemoryIndex = numIndices;
720 if (numOpRanges > maxOpRangeMemoryIndex)
721 maxOpRangeMemoryIndex = numOpRanges;
722 if (numTypeRanges > maxTypeRangeMemoryIndex)
723 maxTypeRangeMemoryIndex = numTypeRanges;
724 if (numValueRanges > maxValueRangeMemoryIndex)
725 maxValueRangeMemoryIndex = numValueRanges;
728 void Generator::generate(
Region *region, ByteCodeWriter &writer) {
729 llvm::ReversePostOrderTraversal<Region *> rpot(region);
730 for (
Block *block : rpot) {
732 blockToAddr.try_emplace(block, matcherByteCode.size());
734 generate(&op, writer);
738 void Generator::generate(
Operation *op, ByteCodeWriter &writer) {
742 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
743 writer.appendInline(op->
getLoc());
746 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
747 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
748 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
749 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
750 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
751 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
752 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
753 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
754 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
755 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
756 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
757 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
758 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
759 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
760 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
761 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
762 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
763 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
764 pdl_interp::SwitchResultCountOp>(
765 [&](
auto interpOp) { this->generate(interpOp, writer); })
767 llvm_unreachable(
"unknown `pdl_interp` operation");
771 void Generator::generate(pdl_interp::ApplyConstraintOp op,
772 ByteCodeWriter &writer) {
776 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
777 writer.appendPDLValueList(op.getArgs());
778 writer.append(ByteCodeField(op.getIsNegated()));
780 writer.append(ByteCodeField(results.size()));
781 for (
Value result : results) {
785 writer.appendPDLValueKind(result);
788 if (isa<pdl::RangeType>(result.getType()))
789 writer.append(getRangeStorageIndex(result));
790 writer.append(result);
792 writer.append(op.getSuccessors());
794 void Generator::generate(pdl_interp::ApplyRewriteOp op,
795 ByteCodeWriter &writer) {
796 assert(externalRewriterToMemIndex.count(op.getName()) &&
797 "expected index for rewrite function");
798 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
799 writer.appendPDLValueList(op.getArgs());
802 writer.append(ByteCodeField(results.size()));
803 for (
Value result : results) {
806 writer.appendPDLValueKind(result);
809 if (isa<pdl::RangeType>(result.getType()))
810 writer.append(getRangeStorageIndex(result));
811 writer.append(result);
814 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
815 Value lhs = op.getLhs();
816 if (isa<pdl::RangeType>(lhs.
getType())) {
817 writer.append(OpCode::AreRangesEqual);
818 writer.appendPDLValueKind(lhs);
819 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
823 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
825 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
828 void Generator::generate(pdl_interp::CheckAttributeOp op,
829 ByteCodeWriter &writer) {
830 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
833 void Generator::generate(pdl_interp::CheckOperandCountOp op,
834 ByteCodeWriter &writer) {
835 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
836 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
839 void Generator::generate(pdl_interp::CheckOperationNameOp op,
840 ByteCodeWriter &writer) {
841 writer.append(OpCode::CheckOperationName, op.getInputOp(),
844 void Generator::generate(pdl_interp::CheckResultCountOp op,
845 ByteCodeWriter &writer) {
846 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
847 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
850 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
851 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
854 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
855 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
858 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
859 assert(curLoopLevel > 0 &&
"encountered pdl_interp.continue at top level");
860 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
862 void Generator::generate(pdl_interp::CreateAttributeOp op,
863 ByteCodeWriter &writer) {
865 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
867 void Generator::generate(pdl_interp::CreateOperationOp op,
868 ByteCodeWriter &writer) {
869 writer.append(OpCode::CreateOperation, op.getResultOp(),
871 writer.appendPDLValueList(op.getInputOperands());
875 writer.append(
static_cast<ByteCodeField
>(attributes.size()));
876 for (
auto it : llvm::zip(op.getInputAttributeNames(), attributes))
877 writer.append(std::get<0>(it), std::get<1>(it));
881 if (op.getInferredResultTypes())
884 writer.appendPDLValueList(op.getInputResultTypes());
886 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
890 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
891 .Case([&](pdl::ValueType) {
892 writer.append(OpCode::CreateDynamicValueRange);
895 writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
896 writer.appendPDLValueList(op->getOperands());
898 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
900 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
902 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
903 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
904 getRangeStorageIndex(op.getResult()), op.getValue());
906 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
907 writer.append(OpCode::EraseOp, op.getInputOp());
909 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
912 .Case([](pdl::OperationType) {
return OpCode::ExtractOp; })
913 .Case([](pdl::ValueType) {
return OpCode::ExtractValue; })
914 .Case([](pdl::TypeType) {
return OpCode::ExtractType; })
915 .Default([](
Type) -> OpCode {
916 llvm_unreachable(
"unsupported element type");
918 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
920 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
921 writer.append(OpCode::Finalize);
923 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
925 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
926 writer.appendPDLValueKind(arg.
getType());
927 writer.append(curLoopLevel, op.getSuccessor());
929 if (curLoopLevel > maxLoopLevel)
930 maxLoopLevel = curLoopLevel;
931 generate(&op.getRegion(), writer);
934 void Generator::generate(pdl_interp::GetAttributeOp op,
935 ByteCodeWriter &writer) {
936 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
939 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
940 ByteCodeWriter &writer) {
941 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
943 void Generator::generate(pdl_interp::GetDefiningOpOp op,
944 ByteCodeWriter &writer) {
945 writer.append(OpCode::GetDefiningOp, op.getInputOp());
946 writer.appendPDLValue(op.getValue());
948 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
949 uint32_t index = op.getIndex();
951 writer.append(
static_cast<OpCode
>(OpCode::GetOperand0 + index));
953 writer.append(OpCode::GetOperandN, index);
954 writer.append(op.getInputOp(), op.getValue());
956 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
957 Value result = op.getValue();
958 std::optional<uint32_t> index = op.getIndex();
959 writer.append(OpCode::GetOperands,
962 if (isa<pdl::RangeType>(result.
getType()))
963 writer.append(getRangeStorageIndex(result));
966 writer.append(result);
968 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
969 uint32_t index = op.getIndex();
971 writer.append(
static_cast<OpCode
>(OpCode::GetResult0 + index));
973 writer.append(OpCode::GetResultN, index);
974 writer.append(op.getInputOp(), op.getValue());
976 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
977 Value result = op.getValue();
978 std::optional<uint32_t> index = op.getIndex();
979 writer.append(OpCode::GetResults,
982 if (isa<pdl::RangeType>(result.
getType()))
983 writer.append(getRangeStorageIndex(result));
986 writer.append(result);
988 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
989 Value operations = op.getOperations();
990 ByteCodeField rangeIndex = getRangeStorageIndex(operations);
991 writer.append(OpCode::GetUsers, operations, rangeIndex);
992 writer.appendPDLValue(op.getValue());
994 void Generator::generate(pdl_interp::GetValueTypeOp op,
995 ByteCodeWriter &writer) {
996 if (isa<pdl::RangeType>(op.getType())) {
997 Value result = op.getResult();
998 writer.append(OpCode::GetValueRangeTypes, result,
999 getRangeStorageIndex(result), op.getValue());
1001 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
1004 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1005 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1007 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1008 ByteCodeField patternIndex =
patterns.size();
1009 patterns.emplace_back(PDLByteCodePattern::create(
1010 op, configMap.lookup(op),
1011 rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1012 writer.append(OpCode::RecordMatch, patternIndex,
1014 writer.appendPDLValueList(op.getInputs());
1016 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1017 writer.append(OpCode::ReplaceOp, op.getInputOp());
1018 writer.appendPDLValueList(op.getReplValues());
1020 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1021 ByteCodeWriter &writer) {
1022 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1023 op.getCaseValuesAttr(), op.getSuccessors());
1025 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1026 ByteCodeWriter &writer) {
1027 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1028 op.getCaseValuesAttr(), op.getSuccessors());
1030 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1031 ByteCodeWriter &writer) {
1032 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](
Attribute attr) {
1033 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1035 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1036 op.getSuccessors());
1038 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1039 ByteCodeWriter &writer) {
1040 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1041 op.getCaseValuesAttr(), op.getSuccessors());
1043 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1044 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1045 op.getSuccessors());
1047 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1048 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1049 op.getSuccessors());
1056 PDLByteCode::PDLByteCode(
1057 ModuleOp module,
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1059 llvm::StringMap<PDLConstraintFunction> constraintFns,
1060 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1061 : configs(std::move(configs)) {
1062 Generator
generator(module.getContext(), uniquedData, matcherByteCode,
1063 rewriterByteCode,
patterns, maxValueMemoryIndex,
1064 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1065 maxLoopLevel, constraintFns, rewriteFns, configMap);
1069 for (
auto &it : constraintFns)
1070 constraintFunctions.push_back(std::move(it.second));
1071 for (
auto &it : rewriteFns)
1072 rewriteFunctions.push_back(std::move(it.second));
1078 state.memory.resize(maxValueMemoryIndex,
nullptr);
1079 state.opRangeMemory.resize(maxOpRangeCount);
1080 state.typeRangeMemory.resize(maxTypeRangeCount,
TypeRange());
1081 state.valueRangeMemory.resize(maxValueRangeCount,
ValueRange());
1082 state.loopIndex.resize(maxLoopLevel, 0);
1083 state.currentPatternBenefits.reserve(
patterns.size());
1085 state.currentPatternBenefits.push_back(pattern.getBenefit());
1096 class ByteCodeRewriteResultList :
public PDLResultList {
1098 ByteCodeRewriteResultList(
unsigned maxNumResults)
1099 : PDLResultList(maxNumResults) {}
1106 return allocatedTypeRanges;
1111 return allocatedValueRanges;
1116 class ByteCodeExecutor {
1122 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1124 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1131 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1132 typeRangeMemory(typeRangeMemory),
1133 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1134 valueRangeMemory(valueRangeMemory),
1135 allocatedValueRangeMemory(allocatedValueRangeMemory),
1136 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1138 constraintFunctions(constraintFunctions),
1139 rewriteFunctions(rewriteFunctions) {}
1147 std::optional<Location> mainRewriteLoc = {});
1153 void executeAreEqual();
1154 void executeAreRangesEqual();
1155 void executeBranch();
1156 void executeCheckOperandCount();
1157 void executeCheckOperationName();
1158 void executeCheckResultCount();
1159 void executeCheckTypes();
1160 void executeContinue();
1161 void executeCreateConstantTypeRange();
1164 template <
typename T>
1165 void executeDynamicCreateRange(StringRef type);
1167 template <
typename T,
typename Range, PDLValue::Kind kind>
1168 void executeExtract();
1169 void executeFinalize();
1170 void executeForEach();
1171 void executeGetAttribute();
1172 void executeGetAttributeType();
1173 void executeGetDefiningOp();
1174 void executeGetOperand(
unsigned index);
1175 void executeGetOperands();
1176 void executeGetResult(
unsigned index);
1177 void executeGetResults();
1178 void executeGetUsers();
1179 void executeGetValueType();
1180 void executeGetValueRangeTypes();
1181 void executeIsNotNull();
1185 void executeSwitchAttribute();
1186 void executeSwitchOperandCount();
1187 void executeSwitchOperationName();
1188 void executeSwitchResultCount();
1189 void executeSwitchType();
1190 void executeSwitchTypes();
1191 void processNativeFunResults(ByteCodeRewriteResultList &results,
1192 unsigned numResults,
1193 LogicalResult &rewriteResult);
1196 void pushCodeIt(
const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1200 assert(!resumeCodeIt.empty() &&
"attempt to pop code off empty stack");
1201 curCodeIt = resumeCodeIt.pop_back_val();
1205 const ByteCodeField *getPrevCodeIt()
const {
1208 return curCodeIt - 1 -
sizeof(
const void *) /
sizeof(ByteCodeField);
1212 return curCodeIt - 1;
1218 template <
typename T = ByteCodeField>
1219 T read(
size_t skipN = 0) {
1221 return readImpl<T>();
// Convenience non-template overload: reads the next value as a plain
// `ByteCodeField`, forwarding `skipN` to the templated `read`. The explicit
// `<ByteCodeField>` template argument is essential: an unqualified
// `read(skipN)` would resolve to this non-template overload again (overload
// resolution prefers a non-template on an equally good match) and recurse
// forever.
1223 ByteCodeField read(
size_t skipN = 0) {
return read<ByteCodeField>(skipN); }
1226 template <
typename ValueT,
typename T>
1229 for (
unsigned i = 0, e = read(); i != e; ++i)
1230 list.push_back(read<ValueT>());
1236 for (
unsigned i = 0, e = read(); i != e; ++i) {
1237 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1238 list.push_back(read<Type>());
1240 TypeRange *values = read<TypeRange *>();
1241 list.append(values->begin(), values->end());
1246 for (
unsigned i = 0, e = read(); i != e; ++i) {
1247 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1248 list.push_back(read<Value>());
1251 list.append(values->begin(), values->end());
1257 template <
typename T>
1258 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1260 const void *pointer;
1261 std::memcpy(&pointer, curCodeIt,
sizeof(
const void *));
1262 curCodeIt +=
sizeof(
const void *) /
sizeof(ByteCodeField);
1263 return T::getFromOpaquePointer(pointer);
1266 void skip(
size_t skipN) { curCodeIt += skipN; }
/// Jumps to the first successor when `isTrue` holds and to the second
/// successor otherwise, by delegating to the index-based overload.
void selectJump(bool isTrue) {
  const size_t successorIndex = isTrue ? 0 : 1;
  selectJump(successorIndex);
}
1271 void selectJump(
size_t destIndex) {
1272 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1276 template <
typename T,
typename RangeT,
typename Comparator = std::equal_to<T>>
1277 void handleSwitch(
const T &value, RangeT &&cases, Comparator cmp = {}) {
1279 llvm::dbgs() <<
" * Value: " << value <<
"\n"
1281 llvm::interleaveComma(cases, llvm::dbgs());
1282 llvm::dbgs() <<
"\n";
1287 for (
auto it = cases.begin(), e = cases.end(); it != e; ++it)
1288 if (cmp(*it, value))
1289 return selectJump(
size_t((it - cases.begin()) + 1));
1290 selectJump(
size_t(0));
1294 void storeToMemory(
unsigned index,
const void *value) {
1295 memory[index] = value;
1299 template <
typename T>
1300 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1301 storeToMemory(
unsigned index, T value) {
1302 memory[index] = value.getAsOpaquePointer();
1307 template <
typename T>
1308 const void *readFromMemory() {
1309 size_t index = *curCodeIt++;
1314 index < memory.size())
1315 return memory[index];
1318 return uniquedMemory[index - memory.size()];
1320 template <
typename T>
1321 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1322 return reinterpret_cast<T
>(
const_cast<void *
>(readFromMemory<T>()));
1324 template <
typename T>
1325 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1328 return T(T::getFromOpaquePointer(readFromMemory<T>()));
1330 template <
typename T>
1331 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1332 switch (read<PDLValue::Kind>()) {
1333 case PDLValue::Kind::Attribute:
1334 return read<Attribute>();
1335 case PDLValue::Kind::Operation:
1336 return read<Operation *>();
1337 case PDLValue::Kind::Type:
1338 return read<Type>();
1339 case PDLValue::Kind::Value:
1340 return read<Value>();
1341 case PDLValue::Kind::TypeRange:
1342 return read<TypeRange *>();
1343 case PDLValue::Kind::ValueRange:
1344 return read<ValueRange *>();
1346 llvm_unreachable(
"unhandled PDLValue::Kind");
1348 template <
typename T>
1349 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1350 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
1351 "unexpected ByteCode address size");
1352 ByteCodeAddr result;
1353 std::memcpy(&result, curCodeIt,
sizeof(ByteCodeAddr));
1357 template <
typename T>
1358 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1359 return *curCodeIt++;
1361 template <
typename T>
1362 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1368 template <
typename RangeT,
typename T = llvm::detail::ValueOfRange<RangeT>>
1369 void assignRangeToMemory(RangeT &&range,
unsigned memIndex,
1370 unsigned rangeIndex) {
1372 auto assignRange = [&](
auto &allocatedRangeMemory,
auto &rangeMemory) {
1374 if (range.empty()) {
1375 rangeMemory[rangeIndex] = {};
1378 llvm::OwningArrayRef<T> storage(llvm::size(range));
1383 allocatedRangeMemory.emplace_back(std::move(storage));
1384 rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1386 memory[memIndex] = &rangeMemory[rangeIndex];
1390 if constexpr (std::is_same_v<T, Type>) {
1391 return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1392 }
else if constexpr (std::is_same_v<T, Value>) {
1393 return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1395 llvm_unreachable(
"unhandled range type");
1400 const ByteCodeField *curCodeIt;
1409 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1411 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1426 void ByteCodeExecutor::executeApplyConstraint(
PatternRewriter &rewriter) {
1427 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyConstraint:\n");
1428 ByteCodeField fun_idx = read();
1430 readList<PDLValue>(args);
1433 llvm::dbgs() <<
" * Arguments: ";
1434 llvm::interleaveComma(args, llvm::dbgs());
1435 llvm::dbgs() <<
"\n";
1438 ByteCodeField isNegated = read();
1440 llvm::dbgs() <<
" * isNegated: " << isNegated <<
"\n";
1441 llvm::interleaveComma(args, llvm::dbgs());
1444 ByteCodeField numResults = read();
1445 const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1446 ByteCodeRewriteResultList results(numResults);
1447 LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1450 if (succeeded(rewriteResult)) {
1451 llvm::dbgs() <<
" * Constraint succeeded\n";
1452 llvm::dbgs() <<
" * Results: ";
1453 llvm::interleaveComma(constraintResults, llvm::dbgs());
1454 llvm::dbgs() <<
"\n";
1456 llvm::dbgs() <<
" * Constraint failed\n";
1459 assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1460 "native PDL rewrite function succeeded but returned "
1461 "unexpected number of results");
1462 processNativeFunResults(results, numResults, rewriteResult);
1465 selectJump(isNegated != succeeded(rewriteResult));
1468 LogicalResult ByteCodeExecutor::executeApplyRewrite(
PatternRewriter &rewriter) {
1469 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyRewrite:\n");
1470 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1472 readList<PDLValue>(args);
1475 llvm::dbgs() <<
" * Arguments: ";
1476 llvm::interleaveComma(args, llvm::dbgs());
1480 ByteCodeField numResults = read();
1481 ByteCodeRewriteResultList results(numResults);
1482 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1484 assert(results.getResults().size() == numResults &&
1485 "native PDL rewrite function returned unexpected number of results");
1487 processNativeFunResults(results, numResults, rewriteResult);
1489 if (failed(rewriteResult)) {
1490 LLVM_DEBUG(llvm::dbgs() <<
" - Failed");
1496 void ByteCodeExecutor::processNativeFunResults(
1497 ByteCodeRewriteResultList &results,
unsigned numResults,
1498 LogicalResult &rewriteResult) {
1499 if (failed(rewriteResult)) {
1502 for (
unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1504 if (resultKind == PDLValue::Kind::TypeRange ||
1505 resultKind == PDLValue::Kind::ValueRange) {
1515 for (
unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1518 PDLValue result = results.getResults()[resultIdx];
1519 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << result <<
"\n");
1520 assert(result.getKind() == resultKind &&
1521 "native PDL rewrite function returned an unexpected type of "
1525 if (std::optional<TypeRange> typeRange = result.dyn_cast<
TypeRange>()) {
1526 unsigned rangeIndex = read();
1527 typeRangeMemory[rangeIndex] = *typeRange;
1528 memory[read()] = &typeRangeMemory[rangeIndex];
1529 }
else if (std::optional<ValueRange> valueRange =
1531 unsigned rangeIndex = read();
1532 valueRangeMemory[rangeIndex] = *valueRange;
1533 memory[read()] = &valueRangeMemory[rangeIndex];
1535 memory[read()] = result.getAsOpaquePointer();
1540 for (
auto &it : results.getAllocatedTypeRanges())
1541 allocatedTypeRangeMemory.push_back(std::move(it));
1542 for (
auto &it : results.getAllocatedValueRanges())
1543 allocatedValueRangeMemory.push_back(std::move(it));
1546 void ByteCodeExecutor::executeAreEqual() {
1547 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreEqual:\n");
1548 const void *lhs = read<const void *>();
1549 const void *rhs = read<const void *>();
1551 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n");
1552 selectJump(lhs == rhs);
1555 void ByteCodeExecutor::executeAreRangesEqual() {
1556 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreRangesEqual:\n");
1558 const void *lhs = read<const void *>();
1559 const void *rhs = read<const void *>();
1561 switch (valueKind) {
1562 case PDLValue::Kind::TypeRange: {
1565 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1566 selectJump(*lhsRange == *rhsRange);
1569 case PDLValue::Kind::ValueRange: {
1570 const auto *lhsRange =
reinterpret_cast<const ValueRange *
>(lhs);
1571 const auto *rhsRange =
reinterpret_cast<const ValueRange *
>(rhs);
1572 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1573 selectJump(*lhsRange == *rhsRange);
1577 llvm_unreachable(
"unexpected `AreRangesEqual` value kind");
1581 void ByteCodeExecutor::executeBranch() {
1582 LLVM_DEBUG(llvm::dbgs() <<
"Executing Branch\n");
1583 curCodeIt = &code[read<ByteCodeAddr>()];
1586 void ByteCodeExecutor::executeCheckOperandCount() {
1587 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperandCount:\n");
1589 uint32_t expectedCount = read<uint32_t>();
1590 bool compareAtLeast = read();
1592 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumOperands() <<
"\n"
1593 <<
" * Expected: " << expectedCount <<
"\n"
1594 <<
" * Comparator: "
1595 << (compareAtLeast ?
">=" :
"==") <<
"\n");
1602 void ByteCodeExecutor::executeCheckOperationName() {
1603 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperationName:\n");
1607 LLVM_DEBUG(llvm::dbgs() <<
" * Found: \"" << op->
getName() <<
"\"\n"
1608 <<
" * Expected: \"" << expectedName <<
"\"\n");
1609 selectJump(op->
getName() == expectedName);
1612 void ByteCodeExecutor::executeCheckResultCount() {
1613 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckResultCount:\n");
1615 uint32_t expectedCount = read<uint32_t>();
1616 bool compareAtLeast = read();
1618 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumResults() <<
"\n"
1619 <<
" * Expected: " << expectedCount <<
"\n"
1620 <<
" * Comparator: "
1621 << (compareAtLeast ?
">=" :
"==") <<
"\n");
1628 void ByteCodeExecutor::executeCheckTypes() {
1629 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreEqual:\n");
1632 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1634 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1637 void ByteCodeExecutor::executeContinue() {
1638 ByteCodeField level = read();
1639 LLVM_DEBUG(llvm::dbgs() <<
"Executing Continue\n"
1640 <<
" * Level: " << level <<
"\n");
1645 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1646 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateConstantTypeRange:\n");
1647 unsigned memIndex = read();
1648 unsigned rangeIndex = read();
1649 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1651 LLVM_DEBUG(llvm::dbgs() <<
" * Types: " << typesAttr <<
"\n\n");
1652 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1656 void ByteCodeExecutor::executeCreateOperation(
PatternRewriter &rewriter,
1658 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateOperation:\n");
1660 unsigned memIndex = read();
1662 readList(state.operands);
1663 for (
unsigned i = 0, e = read(); i != e; ++i) {
1664 StringAttr name = read<StringAttr>();
1666 state.addAttribute(name, attr);
1671 unsigned numResults = read();
1673 InferTypeOpInterface::Concept *inferInterface =
1674 state.name.getInterface<InferTypeOpInterface>();
1675 assert(inferInterface &&
1676 "expected operation to provide InferTypeOpInterface");
1679 if (failed(inferInterface->inferReturnTypes(
1680 state.getContext(), state.location, state.operands,
1681 state.attributes.getDictionary(state.getContext()),
1682 state.getRawProperties(), state.regions, state.types)))
1686 for (
unsigned i = 0; i != numResults; ++i) {
1687 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1688 state.types.push_back(read<Type>());
1690 TypeRange *resultTypes = read<TypeRange *>();
1691 state.types.append(resultTypes->begin(), resultTypes->end());
1697 memory[memIndex] = resultOp;
1700 llvm::dbgs() <<
" * Attributes: "
1701 << state.attributes.getDictionary(state.getContext())
1702 <<
"\n * Operands: ";
1703 llvm::interleaveComma(state.operands, llvm::dbgs());
1704 llvm::dbgs() <<
"\n * Result Types: ";
1705 llvm::interleaveComma(state.types, llvm::dbgs());
1706 llvm::dbgs() <<
"\n * Result: " << *resultOp <<
"\n";
1710 template <
typename T>
1711 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1712 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateDynamic" << type <<
"Range:\n");
1713 unsigned memIndex = read();
1714 unsigned rangeIndex = read();
1719 llvm::dbgs() <<
"\n * " << type <<
"s: ";
1720 llvm::interleaveComma(values, llvm::dbgs());
1721 llvm::dbgs() <<
"\n";
1724 assignRangeToMemory(values, memIndex, rangeIndex);
1728 LLVM_DEBUG(llvm::dbgs() <<
"Executing EraseOp:\n");
1731 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
1735 template <
typename T,
typename Range, PDLValue::Kind kind>
1736 void ByteCodeExecutor::executeExtract() {
1737 LLVM_DEBUG(llvm::dbgs() <<
"Executing Extract" <<
kind <<
":\n");
1738 Range *range = read<Range *>();
1739 unsigned index = read<uint32_t>();
1740 unsigned memIndex = read();
1743 memory[memIndex] =
nullptr;
1747 T result = index < range->
size() ? (*range)[index] : T();
1748 LLVM_DEBUG(llvm::dbgs() <<
" * " <<
kind <<
"s(" << range->
size() <<
")\n"
1749 <<
" * Index: " << index <<
"\n"
1750 <<
" * Result: " << result <<
"\n");
1751 storeToMemory(memIndex, result);
1754 void ByteCodeExecutor::executeFinalize() {
1755 LLVM_DEBUG(llvm::dbgs() <<
"Executing Finalize\n");
1758 void ByteCodeExecutor::executeForEach() {
1759 LLVM_DEBUG(llvm::dbgs() <<
"Executing ForEach:\n");
1760 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1761 unsigned rangeIndex = read();
1762 unsigned memIndex = read();
1763 const void *value =
nullptr;
1765 switch (read<PDLValue::Kind>()) {
1766 case PDLValue::Kind::Operation: {
1767 unsigned &index = loopIndex[read()];
1769 assert(index <= array.size() &&
"iterated past the end");
1770 if (index < array.size()) {
1771 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << array[index] <<
"\n");
1772 value = array[index];
1776 LLVM_DEBUG(llvm::dbgs() <<
" * Done\n");
1778 selectJump(
size_t(0));
1782 llvm_unreachable(
"unexpected `ForEach` value kind");
1786 memory[memIndex] = value;
1787 pushCodeIt(prevCodeIt);
1790 read<ByteCodeAddr>();
1793 void ByteCodeExecutor::executeGetAttribute() {
1794 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttribute:\n");
1795 unsigned memIndex = read();
1797 StringAttr attrName = read<StringAttr>();
1800 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1801 <<
" * Attribute: " << attrName <<
"\n"
1802 <<
" * Result: " << attr <<
"\n");
1806 void ByteCodeExecutor::executeGetAttributeType() {
1807 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttributeType:\n");
1808 unsigned memIndex = read();
1811 if (
auto typedAttr = dyn_cast<TypedAttr>(attr))
1812 type = typedAttr.getType();
1814 LLVM_DEBUG(llvm::dbgs() <<
" * Attribute: " << attr <<
"\n"
1815 <<
" * Result: " << type <<
"\n");
1819 void ByteCodeExecutor::executeGetDefiningOp() {
1820 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetDefiningOp:\n");
1821 unsigned memIndex = read();
1823 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1824 Value value = read<Value>();
1827 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1830 if (values && !values->empty()) {
1831 op = values->front().getDefiningOp();
1833 LLVM_DEBUG(llvm::dbgs() <<
" * Values: " << values <<
"\n");
1836 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << op <<
"\n");
1837 memory[memIndex] = op;
1840 void ByteCodeExecutor::executeGetOperand(
unsigned index) {
1842 unsigned memIndex = read();
1846 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1847 <<
" * Index: " << index <<
"\n"
1848 <<
" * Result: " << operand <<
"\n");
1855 template <
template <
typename>
class AttrSizedSegmentsT,
typename RangeT>
1858 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1863 LLVM_DEBUG(llvm::dbgs() <<
" * Getting all values\n");
1867 }
else if (op->
hasTrait<AttrSizedSegmentsT>()) {
1868 LLVM_DEBUG(llvm::dbgs()
1869 <<
" * Extracting values from `" << attrSizedSegments <<
"`\n");
1872 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1876 unsigned startIndex =
1877 std::accumulate(segments.begin(), segments.begin() + index, 0);
1878 values = values.slice(startIndex, *std::next(segments.begin(), index));
1880 LLVM_DEBUG(llvm::dbgs() <<
" * Extracting range[" << startIndex <<
", "
1881 << *std::next(segments.begin(), index) <<
"]\n");
1887 }
else if (values.size() >= index) {
1888 LLVM_DEBUG(llvm::dbgs()
1889 <<
" * Treating values as trailing variadic range\n");
1890 values = values.drop_front(index);
1899 valueRangeMemory[rangeIndex] = values;
1900 return &valueRangeMemory[rangeIndex];
1904 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1907 void ByteCodeExecutor::executeGetOperands() {
1908 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperands:\n");
1909 unsigned index = read<uint32_t>();
1911 ByteCodeField rangeIndex = read();
1913 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1914 op->
getOperands(), op, index, rangeIndex,
"operandSegmentSizes",
1917 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid operand range\n");
1918 memory[read()] = result;
1921 void ByteCodeExecutor::executeGetResult(
unsigned index) {
1923 unsigned memIndex = read();
1927 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1928 <<
" * Index: " << index <<
"\n"
1929 <<
" * Result: " << result <<
"\n");
1933 void ByteCodeExecutor::executeGetResults() {
1934 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResults:\n");
1935 unsigned index = read<uint32_t>();
1937 ByteCodeField rangeIndex = read();
1939 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1940 op->
getResults(), op, index, rangeIndex,
"resultSegmentSizes",
1943 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid result range\n");
1944 memory[read()] = result;
1947 void ByteCodeExecutor::executeGetUsers() {
1948 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetUsers:\n");
1949 unsigned memIndex = read();
1950 unsigned rangeIndex = read();
1951 OwningOpRange &range = opRangeMemory[rangeIndex];
1952 memory[memIndex] = ⦥
1954 range = OwningOpRange();
1955 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1957 Value value = read<Value>();
1960 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1971 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
1972 llvm::interleaveComma(*values, llvm::dbgs());
1973 llvm::dbgs() <<
"\n";
1978 for (
Value value : *values)
1980 range = OwningOpRange(users.size());
1984 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << range.size() <<
" operations\n");
1987 void ByteCodeExecutor::executeGetValueType() {
1988 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueType:\n");
1989 unsigned memIndex = read();
1990 Value value = read<Value>();
1993 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n"
1994 <<
" * Result: " << type <<
"\n");
1998 void ByteCodeExecutor::executeGetValueRangeTypes() {
1999 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueRangeTypes:\n");
2000 unsigned memIndex = read();
2001 unsigned rangeIndex = read();
2004 LLVM_DEBUG(llvm::dbgs() <<
" * Values: <NULL>\n\n");
2005 memory[memIndex] =
nullptr;
2010 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
2011 llvm::interleaveComma(*values, llvm::dbgs());
2012 llvm::dbgs() <<
"\n * Result: ";
2013 llvm::interleaveComma(values->
getType(), llvm::dbgs());
2014 llvm::dbgs() <<
"\n";
2016 typeRangeMemory[rangeIndex] = values->
getType();
2017 memory[memIndex] = &typeRangeMemory[rangeIndex];
2020 void ByteCodeExecutor::executeIsNotNull() {
2021 LLVM_DEBUG(llvm::dbgs() <<
"Executing IsNotNull:\n");
2022 const void *value = read<const void *>();
2024 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
2025 selectJump(value !=
nullptr);
2028 void ByteCodeExecutor::executeRecordMatch(
2031 LLVM_DEBUG(llvm::dbgs() <<
"Executing RecordMatch:\n");
2032 unsigned patternIndex = read();
2034 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
2039 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: Impossible To Match\n");
2048 unsigned numMatchLocs = read();
2050 matchLocs.reserve(numMatchLocs);
2051 for (
unsigned i = 0; i != numMatchLocs; ++i)
2052 matchLocs.push_back(read<Operation *>()->getLoc());
2055 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: " << benefit.
getBenefit() <<
"\n"
2056 <<
" * Location: " << matchLoc <<
"\n");
2057 matches.emplace_back(matchLoc,
patterns[patternIndex], benefit);
2063 unsigned numInputs = read();
2064 match.values.reserve(numInputs);
2065 match.typeRangeValues.reserve(numInputs);
2066 match.valueRangeValues.reserve(numInputs);
2067 for (
unsigned i = 0; i < numInputs; ++i) {
2068 switch (read<PDLValue::Kind>()) {
2069 case PDLValue::Kind::TypeRange:
2070 match.typeRangeValues.push_back(*read<TypeRange *>());
2071 match.values.push_back(&match.typeRangeValues.back());
2073 case PDLValue::Kind::ValueRange:
2074 match.valueRangeValues.push_back(*read<ValueRange *>());
2075 match.values.push_back(&match.valueRangeValues.back());
2078 match.values.push_back(read<const void *>());
2086 LLVM_DEBUG(llvm::dbgs() <<
"Executing ReplaceOp:\n");
2092 llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
2094 llvm::interleaveComma(args, llvm::dbgs());
2095 llvm::dbgs() <<
"\n";
2100 void ByteCodeExecutor::executeSwitchAttribute() {
2101 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchAttribute:\n");
2103 ArrayAttr cases = read<ArrayAttr>();
2104 handleSwitch(value, cases);
2107 void ByteCodeExecutor::executeSwitchOperandCount() {
2108 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperandCount:\n");
2110 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2112 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
2116 void ByteCodeExecutor::executeSwitchOperationName() {
2117 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperationName:\n");
2119 size_t caseCount = read();
2125 const ByteCodeField *prevCodeIt = curCodeIt;
2126 llvm::dbgs() <<
" * Value: " << value <<
"\n"
2128 llvm::interleaveComma(
2129 llvm::map_range(llvm::seq<size_t>(0, caseCount),
2130 [&](
size_t) {
return read<OperationName>(); }),
2132 llvm::dbgs() <<
"\n";
2133 curCodeIt = prevCodeIt;
2137 for (
size_t i = 0; i != caseCount; ++i) {
2138 if (read<OperationName>() == value) {
2139 curCodeIt += (caseCount - i - 1);
2140 return selectJump(i + 1);
2143 selectJump(
size_t(0));
2146 void ByteCodeExecutor::executeSwitchResultCount() {
2147 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchResultCount:\n");
2149 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2151 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
2155 void ByteCodeExecutor::executeSwitchType() {
2156 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchType:\n");
2157 Type value = read<Type>();
2158 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2159 handleSwitch(value, cases);
2162 void ByteCodeExecutor::executeSwitchTypes() {
2163 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchTypes:\n");
2165 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2167 LLVM_DEBUG(llvm::dbgs() <<
"Types: <NULL>\n");
2168 return selectJump(
size_t(0));
2170 handleSwitch(*value, cases, [](ArrayAttr caseValue,
const TypeRange &value) {
2171 return value == caseValue.getAsValueRange<TypeAttr>();
2178 std::optional<Location> mainRewriteLoc) {
2181 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() <<
"\n");
2183 OpCode opCode =
static_cast<OpCode
>(read());
2185 case ApplyConstraint:
2186 executeApplyConstraint(rewriter);
2189 if (failed(executeApplyRewrite(rewriter)))
2195 case AreRangesEqual:
2196 executeAreRangesEqual();
2201 case CheckOperandCount:
2202 executeCheckOperandCount();
2204 case CheckOperationName:
2205 executeCheckOperationName();
2207 case CheckResultCount:
2208 executeCheckResultCount();
2211 executeCheckTypes();
2216 case CreateConstantTypeRange:
2217 executeCreateConstantTypeRange();
2219 case CreateOperation:
2220 executeCreateOperation(rewriter, *mainRewriteLoc);
2222 case CreateDynamicTypeRange:
2223 executeDynamicCreateRange<Type>(
"Type");
2225 case CreateDynamicValueRange:
2226 executeDynamicCreateRange<Value>(
"Value");
2229 executeEraseOp(rewriter);
2232 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2235 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2238 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2242 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2248 executeGetAttribute();
2250 case GetAttributeType:
2251 executeGetAttributeType();
2254 executeGetDefiningOp();
2260 unsigned index = opCode - GetOperand0;
2261 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperand" << index <<
":\n");
2262 executeGetOperand(index);
2266 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperandN:\n");
2267 executeGetOperand(read<uint32_t>());
2270 executeGetOperands();
2276 unsigned index = opCode - GetResult0;
2277 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResult" << index <<
":\n");
2278 executeGetResult(index);
2282 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResultN:\n");
2283 executeGetResult(read<uint32_t>());
2286 executeGetResults();
2292 executeGetValueType();
2294 case GetValueRangeTypes:
2295 executeGetValueRangeTypes();
2302 "expected matches to be provided when executing the matcher");
2303 executeRecordMatch(rewriter, *matches);
2306 executeReplaceOp(rewriter);
2308 case SwitchAttribute:
2309 executeSwitchAttribute();
2311 case SwitchOperandCount:
2312 executeSwitchOperandCount();
2314 case SwitchOperationName:
2315 executeSwitchOperationName();
2317 case SwitchResultCount:
2318 executeSwitchResultCount();
2321 executeSwitchType();
2324 executeSwitchTypes();
2327 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2335 state.memory[0] = op;
2338 ByteCodeExecutor executor(
2339 matcherByteCode.data(), state.memory, state.opRangeMemory,
2340 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2341 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2342 uniquedData, matcherByteCode, state.currentPatternBenefits,
patterns,
2343 constraintFunctions, rewriteFunctions);
2344 LogicalResult executeResult = executor.execute(rewriter, &matches);
2345 (void)executeResult;
2346 assert(succeeded(executeResult) &&
"unexpected matcher execution failure");
2349 llvm::stable_sort(matches,
2350 [](
const MatchResult &lhs,
const MatchResult &rhs) {
2351 return lhs.benefit > rhs.benefit;
2356 const MatchResult &match,
2358 auto *configSet =
match.pattern->getConfigSet();
2360 configSet->notifyRewriteBegin(rewriter);
2366 ByteCodeExecutor executor(
2367 &rewriterByteCode[
match.pattern->getRewriterAddr()], state.memory,
2368 state.opRangeMemory, state.typeRangeMemory,
2369 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2370 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2371 rewriterByteCode, state.currentPatternBenefits,
patterns,
2372 constraintFunctions, rewriteFunctions);
2373 LogicalResult result =
2374 executor.execute(rewriter,
nullptr,
match.location);
2377 configSet->notifyRewriteEnd(rewriter);
2386 LLVM_DEBUG(llvm::dbgs() <<
" and rollback is not supported - aborting");
2387 llvm::report_fatal_error(
2388 "Native PDL Rewrite failed, but the pattern "
2389 "rewriter doesn't support recovery. Failable pattern rewrites should "
2390 "not be used with pattern rewriters that do not support them.");
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
This class represents liveness information on block level.
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Represents an analysis for computing liveness information from a given top-level operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class implements the result iterators for the Operation class.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
auto walk(WalkFns &&...walkFns)
Walk this type and all attibutes/types nested within using the provided walk functions.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_iterator user_begin() const
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
user_iterator user_end() const
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...