19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
28 #define DEBUG_TYPE "pdl-bytecode"
38 PDLPatternConfigSet *configSet,
39 ByteCodeAddr rewriterAddr) {
45 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
47 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
50 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
54 benefit, ctx, generatedOps);
66 currentPatternBenefits[patternIndex] = benefit;
73 allocatedTypeRangeMemory.clear();
74 allocatedValueRangeMemory.clear();
82 enum OpCode : ByteCodeField {
104 CreateConstantTypeRange,
108 CreateDynamicTypeRange,
110 CreateDynamicValueRange,
// Forward declarations: the Generator declared below takes ByteCodeWriter
// parameters before ByteCodeWriter is defined, and ByteCodeLiveRange is also
// defined later in this file.
185 struct ByteCodeLiveRange;
186 struct ByteCodeWriter;
// Detection alias naming the return type of `T::getAsOpaquePointer()`; it is
// ill-formed when T has no such member. It is consumed via
// `llvm::is_detected<has_pointer_traits, T>` to select the overloads (e.g.
// ByteCodeWriter::append/appendInline, executor read/storeToMemory helpers)
// that handle opaque-pointer-backed types.
// NOTE(review): the `Args` pack is not referenced by the alias.
189 template <
typename T,
typename... Args>
190 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
195 Generator(
MLIRContext *ctx, std::vector<const void *> &uniquedData,
199 ByteCodeField &maxValueMemoryIndex,
200 ByteCodeField &maxOpRangeMemoryIndex,
201 ByteCodeField &maxTypeRangeMemoryIndex,
202 ByteCodeField &maxValueRangeMemoryIndex,
203 ByteCodeField &maxLoopLevel,
204 llvm::StringMap<PDLConstraintFunction> &constraintFns,
205 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
207 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
209 maxValueMemoryIndex(maxValueMemoryIndex),
210 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
211 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
212 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
213 maxLoopLevel(maxLoopLevel), configMap(configMap) {
215 constraintToMemIndex.try_emplace(it.value().first(), it.index());
217 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
221 void generate(ModuleOp module);
224 ByteCodeField &getMemIndex(
Value value) {
225 assert(valueToMemIndex.count(value) &&
226 "expected memory index to be assigned");
227 return valueToMemIndex[value];
231 ByteCodeField &getRangeStorageIndex(
Value value) {
232 assert(valueToRangeIndex.count(value) &&
233 "expected range index to be assigned");
234 return valueToRangeIndex[value];
239 template <
typename T>
240 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
242 const void *opaqueVal = val.getAsOpaquePointer();
245 auto it = uniquedDataToMemIndex.try_emplace(
246 opaqueVal, maxValueMemoryIndex + uniquedData.size());
248 uniquedData.push_back(opaqueVal);
249 return it.first->second;
255 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
256 ModuleOp rewriterModule);
259 void generate(
Region *region, ByteCodeWriter &writer);
260 void generate(
Operation *op, ByteCodeWriter &writer);
261 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
262 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
263 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
298 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
308 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
312 llvm::StringMap<ByteCodeField> constraintToMemIndex;
316 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
323 ByteCodeField curLoopLevel = 0;
332 std::vector<const void *> &uniquedData;
336 ByteCodeField &maxValueMemoryIndex;
337 ByteCodeField &maxOpRangeMemoryIndex;
338 ByteCodeField &maxTypeRangeMemoryIndex;
339 ByteCodeField &maxValueRangeMemoryIndex;
340 ByteCodeField &maxLoopLevel;
347 struct ByteCodeWriter {
// Append a single raw field to the bytecode stream.
352 void append(ByteCodeField field) { bytecode.push_back(field); }
// Append an opcode. OpCode's underlying type is ByteCodeField, so the opcode
// occupies exactly one field in the stream.
353 void append(OpCode opCode) { bytecode.push_back(opCode); }
356 void append(ByteCodeAddr field) {
357 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
358 "unexpected ByteCode address size");
360 ByteCodeField fieldParts[2];
361 std::memcpy(fieldParts, &field,
sizeof(ByteCodeAddr));
362 bytecode.append({fieldParts[0], fieldParts[1]});
367 void append(
Block *successor) {
370 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
371 append(ByteCodeAddr(0));
377 for (
Block *successor : successors)
383 bytecode.push_back(values.size());
384 for (
Value value : values)
385 appendPDLValue(value);
389 void appendPDLValue(
Value value) {
390 appendPDLValueKind(value);
395 void appendPDLValueKind(
Value value) { appendPDLValueKind(value.
getType()); }
398 void appendPDLValueKind(
Type type) {
401 .Case<pdl::AttributeType>(
402 [](
Type) {
return PDLValue::Kind::Attribute; })
403 .Case<pdl::OperationType>(
404 [](
Type) {
return PDLValue::Kind::Operation; })
405 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
406 if (isa<pdl::TypeType>(rangeTy.getElementType()))
407 return PDLValue::Kind::TypeRange;
408 return PDLValue::Kind::ValueRange;
410 .Case<pdl::TypeType>([](
Type) {
return PDLValue::Kind::Type; })
411 .Case<pdl::ValueType>([](
Type) {
return PDLValue::Kind::Value; });
412 bytecode.push_back(
static_cast<ByteCodeField
>(
kind));
417 template <
typename T>
418 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
419 std::is_pointer<T>::value>
421 bytecode.push_back(
generator.getMemIndex(value));
425 template <
typename T,
typename IteratorT = llvm::detail::IterOfRange<T>>
426 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
428 bytecode.push_back(llvm::size(range));
429 for (
auto it : range)
434 template <
typename FieldTy,
typename Field2Ty,
typename... FieldTys>
435 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
437 append(field2, fields...);
441 template <
typename T>
442 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
443 appendInline(T value) {
444 constexpr
size_t numParts =
sizeof(
const void *) /
sizeof(ByteCodeField);
445 const void *pointer = value.getAsOpaquePointer();
446 ByteCodeField fieldParts[numParts];
447 std::memcpy(fieldParts, &pointer,
sizeof(
const void *));
448 bytecode.append(fieldParts, fieldParts + numParts);
463 struct ByteCodeLiveRange {
464 using Set = llvm::IntervalMap<uint64_t, char, 16>;
465 using Allocator = Set::Allocator;
467 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
470 void unionWith(
const ByteCodeLiveRange &rhs) {
471 for (
auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
473 liveness->insert(it.start(), it.stop(), 0);
477 bool overlaps(
const ByteCodeLiveRange &rhs)
const {
478 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
488 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
491 std::optional<unsigned> opRangeIndex;
494 std::optional<unsigned> typeRangeIndex;
497 std::optional<unsigned> valueRangeIndex;
501 void Generator::generate(ModuleOp module) {
502 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
503 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
504 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
505 pdl_interp::PDLInterpDialect::getRewriterModuleName());
506 assert(matcherFunc && rewriterModule &&
"invalid PDL Interpreter module");
510 allocateMemoryIndices(matcherFunc, rewriterModule);
513 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *
this);
514 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
515 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
516 for (
Operation &op : rewriterFunc.getOps())
517 generate(&op, rewriterByteCodeWriter);
519 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
520 "unexpected branches in rewriter function");
523 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *
this);
524 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
527 for (
auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
528 ByteCodeAddr addr = blockToAddr[it.first];
529 for (
unsigned offsetToFix : it.second)
530 std::memcpy(&matcherByteCode[offsetToFix], &addr,
sizeof(ByteCodeAddr));
534 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
535 ModuleOp rewriterModule) {
538 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
539 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
540 auto processRewriterValue = [&](
Value val) {
541 valueToMemIndex.try_emplace(val, index++);
542 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
543 Type elementTy = rangeType.getElementType();
544 if (isa<pdl::TypeType>(elementTy))
545 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
546 else if (isa<pdl::ValueType>(elementTy))
547 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
552 processRewriterValue(arg);
555 processRewriterValue(result);
557 if (index > maxValueMemoryIndex)
558 maxValueMemoryIndex = index;
559 if (typeRangeIndex > maxTypeRangeMemoryIndex)
560 maxTypeRangeMemoryIndex = typeRangeIndex;
561 if (valueRangeIndex > maxValueRangeMemoryIndex)
562 maxValueRangeMemoryIndex = valueRangeIndex;
579 opToFirstIndex.try_emplace(op, index++);
581 for (
Block &block : region.getBlocks())
584 opToLastIndex.try_emplace(op, index++);
589 ByteCodeLiveRange::Allocator allocator;
594 valueToMemIndex[rootOpArg] = 0;
597 Liveness matcherLiveness(matcherFunc);
598 matcherFunc->walk([&](
Block *block) {
600 assert(info &&
"expected liveness info for block");
604 if (value == rootOpArg)
608 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
609 defRangeIt->second.liveness->insert(
610 opToFirstIndex[firstUseOrDef],
615 if (
auto rangeTy = dyn_cast<pdl::RangeType>(value.
getType())) {
616 Type eleType = rangeTy.getElementType();
617 if (isa<pdl::OperationType>(eleType))
618 defRangeIt->second.opRangeIndex = 0;
619 else if (isa<pdl::TypeType>(eleType))
620 defRangeIt->second.typeRangeIndex = 0;
621 else if (isa<pdl::ValueType>(eleType))
622 defRangeIt->second.valueRangeIndex = 0;
627 for (
Value liveIn : info->
in()) {
632 if (liveIn.getParentRegion() == block->
getParent())
649 std::vector<ByteCodeLiveRange> allocatedIndices;
653 ByteCodeField numIndices = 1;
656 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
658 for (
auto &defIt : valueDefRanges) {
659 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
660 ByteCodeLiveRange &defRange = defIt.second;
663 for (
const auto &existingIndexIt :
llvm::enumerate(allocatedIndices)) {
664 ByteCodeLiveRange &existingRange = existingIndexIt.value();
665 if (!defRange.overlaps(existingRange)) {
666 existingRange.unionWith(defRange);
667 memIndex = existingIndexIt.index() + 1;
669 if (defRange.opRangeIndex) {
670 if (!existingRange.opRangeIndex)
671 existingRange.opRangeIndex = numOpRanges++;
672 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
673 }
else if (defRange.typeRangeIndex) {
674 if (!existingRange.typeRangeIndex)
675 existingRange.typeRangeIndex = numTypeRanges++;
676 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
677 }
else if (defRange.valueRangeIndex) {
678 if (!existingRange.valueRangeIndex)
679 existingRange.valueRangeIndex = numValueRanges++;
680 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
688 allocatedIndices.emplace_back(allocator);
689 ByteCodeLiveRange &newRange = allocatedIndices.back();
690 newRange.unionWith(defRange);
693 if (defRange.opRangeIndex) {
694 newRange.opRangeIndex = numOpRanges;
695 valueToRangeIndex[defIt.first] = numOpRanges++;
696 }
else if (defRange.typeRangeIndex) {
697 newRange.typeRangeIndex = numTypeRanges;
698 valueToRangeIndex[defIt.first] = numTypeRanges++;
699 }
else if (defRange.valueRangeIndex) {
700 newRange.valueRangeIndex = numValueRanges;
701 valueToRangeIndex[defIt.first] = numValueRanges++;
704 memIndex = allocatedIndices.size();
711 llvm::dbgs() <<
"Allocated " << allocatedIndices.size() <<
" indices "
712 <<
"(down from initial " << valueDefRanges.size() <<
").\n";
715 "Ran out of memory for allocated indices");
718 if (numIndices > maxValueMemoryIndex)
719 maxValueMemoryIndex = numIndices;
720 if (numOpRanges > maxOpRangeMemoryIndex)
721 maxOpRangeMemoryIndex = numOpRanges;
722 if (numTypeRanges > maxTypeRangeMemoryIndex)
723 maxTypeRangeMemoryIndex = numTypeRanges;
724 if (numValueRanges > maxValueRangeMemoryIndex)
725 maxValueRangeMemoryIndex = numValueRanges;
728 void Generator::generate(
Region *region, ByteCodeWriter &writer) {
729 llvm::ReversePostOrderTraversal<Region *> rpot(region);
730 for (
Block *block : rpot) {
732 blockToAddr.try_emplace(block, matcherByteCode.size());
734 generate(&op, writer);
738 void Generator::generate(
Operation *op, ByteCodeWriter &writer) {
742 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
743 writer.appendInline(op->
getLoc());
746 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
747 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
748 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
749 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
750 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
751 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
752 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
753 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
754 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
755 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
756 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
757 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
758 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
759 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
760 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
761 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
762 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
763 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
764 pdl_interp::SwitchResultCountOp>(
765 [&](
auto interpOp) { this->generate(interpOp, writer); })
767 llvm_unreachable(
"unknown `pdl_interp` operation");
771 void Generator::generate(pdl_interp::ApplyConstraintOp op,
772 ByteCodeWriter &writer) {
776 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
777 writer.appendPDLValueList(op.getArgs());
778 writer.append(ByteCodeField(op.getIsNegated()));
780 writer.append(ByteCodeField(results.size()));
781 for (
Value result : results) {
785 writer.appendPDLValueKind(result);
788 if (isa<pdl::RangeType>(result.getType()))
789 writer.append(getRangeStorageIndex(result));
790 writer.append(result);
792 writer.append(op.getSuccessors());
794 void Generator::generate(pdl_interp::ApplyRewriteOp op,
795 ByteCodeWriter &writer) {
796 assert(externalRewriterToMemIndex.count(op.getName()) &&
797 "expected index for rewrite function");
798 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
799 writer.appendPDLValueList(op.getArgs());
802 writer.append(ByteCodeField(results.size()));
803 for (
Value result : results) {
806 writer.appendPDLValueKind(result);
809 if (isa<pdl::RangeType>(result.getType()))
810 writer.append(getRangeStorageIndex(result));
811 writer.append(result);
814 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
815 Value lhs = op.getLhs();
816 if (isa<pdl::RangeType>(lhs.
getType())) {
817 writer.append(OpCode::AreRangesEqual);
818 writer.appendPDLValueKind(lhs);
819 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
823 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
825 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
828 void Generator::generate(pdl_interp::CheckAttributeOp op,
829 ByteCodeWriter &writer) {
830 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
833 void Generator::generate(pdl_interp::CheckOperandCountOp op,
834 ByteCodeWriter &writer) {
835 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
836 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
839 void Generator::generate(pdl_interp::CheckOperationNameOp op,
840 ByteCodeWriter &writer) {
841 writer.append(OpCode::CheckOperationName, op.getInputOp(),
844 void Generator::generate(pdl_interp::CheckResultCountOp op,
845 ByteCodeWriter &writer) {
846 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
847 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
850 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
851 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
854 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
855 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
858 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
859 assert(curLoopLevel > 0 &&
"encountered pdl_interp.continue at top level");
860 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
862 void Generator::generate(pdl_interp::CreateAttributeOp op,
863 ByteCodeWriter &writer) {
865 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
867 void Generator::generate(pdl_interp::CreateOperationOp op,
868 ByteCodeWriter &writer) {
869 writer.append(OpCode::CreateOperation, op.getResultOp(),
871 writer.appendPDLValueList(op.getInputOperands());
875 writer.append(
static_cast<ByteCodeField
>(attributes.size()));
876 for (
auto it : llvm::zip(op.getInputAttributeNames(), attributes))
877 writer.append(std::get<0>(it), std::get<1>(it));
881 if (op.getInferredResultTypes())
884 writer.appendPDLValueList(op.getInputResultTypes());
886 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
890 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
891 .Case([&](pdl::ValueType) {
892 writer.append(OpCode::CreateDynamicValueRange);
895 writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
896 writer.appendPDLValueList(op->getOperands());
898 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
900 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
902 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
903 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
904 getRangeStorageIndex(op.getResult()), op.getValue());
906 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
907 writer.append(OpCode::EraseOp, op.getInputOp());
909 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
912 .Case([](pdl::OperationType) {
return OpCode::ExtractOp; })
913 .Case([](pdl::ValueType) {
return OpCode::ExtractValue; })
914 .Case([](pdl::TypeType) {
return OpCode::ExtractType; })
915 .Default([](
Type) -> OpCode {
916 llvm_unreachable(
"unsupported element type");
918 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
920 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
921 writer.append(OpCode::Finalize);
923 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
925 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
926 writer.appendPDLValueKind(arg.
getType());
927 writer.append(curLoopLevel, op.getSuccessor());
929 if (curLoopLevel > maxLoopLevel)
930 maxLoopLevel = curLoopLevel;
931 generate(&op.getRegion(), writer);
934 void Generator::generate(pdl_interp::GetAttributeOp op,
935 ByteCodeWriter &writer) {
936 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
939 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
940 ByteCodeWriter &writer) {
941 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
943 void Generator::generate(pdl_interp::GetDefiningOpOp op,
944 ByteCodeWriter &writer) {
945 writer.append(OpCode::GetDefiningOp, op.getInputOp());
946 writer.appendPDLValue(op.getValue());
948 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
949 uint32_t index = op.getIndex();
951 writer.append(
static_cast<OpCode
>(OpCode::GetOperand0 + index));
953 writer.append(OpCode::GetOperandN, index);
954 writer.append(op.getInputOp(), op.getValue());
956 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
957 Value result = op.getValue();
958 std::optional<uint32_t> index = op.getIndex();
959 writer.append(OpCode::GetOperands,
962 if (isa<pdl::RangeType>(result.
getType()))
963 writer.append(getRangeStorageIndex(result));
966 writer.append(result);
968 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
969 uint32_t index = op.getIndex();
971 writer.append(
static_cast<OpCode
>(OpCode::GetResult0 + index));
973 writer.append(OpCode::GetResultN, index);
974 writer.append(op.getInputOp(), op.getValue());
976 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
977 Value result = op.getValue();
978 std::optional<uint32_t> index = op.getIndex();
979 writer.append(OpCode::GetResults,
982 if (isa<pdl::RangeType>(result.
getType()))
983 writer.append(getRangeStorageIndex(result));
986 writer.append(result);
988 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
989 Value operations = op.getOperations();
990 ByteCodeField rangeIndex = getRangeStorageIndex(operations);
991 writer.append(OpCode::GetUsers, operations, rangeIndex);
992 writer.appendPDLValue(op.getValue());
994 void Generator::generate(pdl_interp::GetValueTypeOp op,
995 ByteCodeWriter &writer) {
996 if (isa<pdl::RangeType>(op.getType())) {
997 Value result = op.getResult();
998 writer.append(OpCode::GetValueRangeTypes, result,
999 getRangeStorageIndex(result), op.getValue());
1001 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
1004 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1005 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1007 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1008 ByteCodeField patternIndex =
patterns.size();
1009 patterns.emplace_back(PDLByteCodePattern::create(
1010 op, configMap.lookup(op),
1011 rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1012 writer.append(OpCode::RecordMatch, patternIndex,
1014 writer.appendPDLValueList(op.getInputs());
1016 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1017 writer.append(OpCode::ReplaceOp, op.getInputOp());
1018 writer.appendPDLValueList(op.getReplValues());
1020 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1021 ByteCodeWriter &writer) {
1022 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1023 op.getCaseValuesAttr(), op.getSuccessors());
1025 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1026 ByteCodeWriter &writer) {
1027 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1028 op.getCaseValuesAttr(), op.getSuccessors());
1030 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1031 ByteCodeWriter &writer) {
1032 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](
Attribute attr) {
1033 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1035 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1036 op.getSuccessors());
1038 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1039 ByteCodeWriter &writer) {
1040 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1041 op.getCaseValuesAttr(), op.getSuccessors());
1043 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1044 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1045 op.getSuccessors());
1047 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1048 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1049 op.getSuccessors());
1056 PDLByteCode::PDLByteCode(
1057 ModuleOp module,
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1059 llvm::StringMap<PDLConstraintFunction> constraintFns,
1060 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1061 : configs(std::move(configs)) {
1062 Generator
generator(module.getContext(), uniquedData, matcherByteCode,
1063 rewriterByteCode,
patterns, maxValueMemoryIndex,
1064 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1065 maxLoopLevel, constraintFns, rewriteFns, configMap);
1069 for (
auto &it : constraintFns)
1070 constraintFunctions.push_back(std::move(it.second));
1071 for (
auto &it : rewriteFns)
1072 rewriteFunctions.push_back(std::move(it.second));
1078 state.memory.resize(maxValueMemoryIndex,
nullptr);
1079 state.opRangeMemory.resize(maxOpRangeCount);
1080 state.typeRangeMemory.resize(maxTypeRangeCount,
TypeRange());
1081 state.valueRangeMemory.resize(maxValueRangeCount,
ValueRange());
1082 state.loopIndex.resize(maxLoopLevel, 0);
1083 state.currentPatternBenefits.reserve(
patterns.size());
1085 state.currentPatternBenefits.push_back(pattern.getBenefit());
1096 class ByteCodeRewriteResultList :
public PDLResultList {
1098 ByteCodeRewriteResultList(
unsigned maxNumResults)
1099 : PDLResultList(maxNumResults) {}
1106 return allocatedTypeRanges;
1111 return allocatedValueRanges;
1116 class ByteCodeExecutor {
1122 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1124 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1131 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1132 typeRangeMemory(typeRangeMemory),
1133 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1134 valueRangeMemory(valueRangeMemory),
1135 allocatedValueRangeMemory(allocatedValueRangeMemory),
1136 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1138 constraintFunctions(constraintFunctions),
1139 rewriteFunctions(rewriteFunctions) {}
1147 std::optional<Location> mainRewriteLoc = {});
1153 void executeAreEqual();
1154 void executeAreRangesEqual();
1155 void executeBranch();
1156 void executeCheckOperandCount();
1157 void executeCheckOperationName();
1158 void executeCheckResultCount();
1159 void executeCheckTypes();
1160 void executeContinue();
1161 void executeCreateConstantTypeRange();
1164 template <
typename T>
1165 void executeDynamicCreateRange(StringRef type);
1167 template <
typename T,
typename Range, PDLValue::Kind kind>
1168 void executeExtract();
1169 void executeFinalize();
1170 void executeForEach();
1171 void executeGetAttribute();
1172 void executeGetAttributeType();
1173 void executeGetDefiningOp();
1174 void executeGetOperand(
unsigned index);
1175 void executeGetOperands();
1176 void executeGetResult(
unsigned index);
1177 void executeGetResults();
1178 void executeGetUsers();
1179 void executeGetValueType();
1180 void executeGetValueRangeTypes();
1181 void executeIsNotNull();
1185 void executeSwitchAttribute();
1186 void executeSwitchOperandCount();
1187 void executeSwitchOperationName();
1188 void executeSwitchResultCount();
1189 void executeSwitchType();
1190 void executeSwitchTypes();
1191 void processNativeFunResults(ByteCodeRewriteResultList &results,
1192 unsigned numResults,
1193 LogicalResult &rewriteResult);
// Save `it` on the resume stack so execution can later continue from that
// position; the stack is popped back into curCodeIt when resuming.
1196 void pushCodeIt(
const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1200 assert(!resumeCodeIt.empty() &&
"attempt to pop code off empty stack");
1201 curCodeIt = resumeCodeIt.pop_back_val();
1205 const ByteCodeField *getPrevCodeIt()
const {
1208 return curCodeIt - 1 -
sizeof(
const void *) /
sizeof(ByteCodeField);
1212 return curCodeIt - 1;
1218 template <
typename T = ByteCodeField>
1219 T read(
size_t skipN = 0) {
1221 return readImpl<T>();
// Non-template convenience overload: reads a ByteCodeField by forwarding
// `skipN` to read<ByteCodeField>, so callers can write `read()` without an
// explicit template argument.
1223 ByteCodeField read(
size_t skipN = 0) {
return read<ByteCodeField>(skipN); }
1226 template <
typename ValueT,
typename T>
1229 for (
unsigned i = 0, e = read(); i != e; ++i)
1230 list.push_back(read<ValueT>());
1236 for (
unsigned i = 0, e = read(); i != e; ++i) {
1237 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1238 list.push_back(read<Type>());
1240 TypeRange *values = read<TypeRange *>();
1241 list.append(values->begin(), values->end());
1246 for (
unsigned i = 0, e = read(); i != e; ++i) {
1247 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1248 list.push_back(read<Value>());
1251 list.append(values->begin(), values->end());
1257 template <
typename T>
1258 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1260 const void *pointer;
1261 std::memcpy(&pointer, curCodeIt,
sizeof(
const void *));
1262 curCodeIt +=
sizeof(
const void *) /
sizeof(ByteCodeField);
1263 return T::getFromOpaquePointer(pointer);
// Advance the code iterator past `skipN` bytecode fields without reading them.
1266 void skip(
size_t skipN) { curCodeIt += skipN; }
// Conditional branch: jump to destination index 0 when `isTrue`, otherwise to
// destination index 1 (delegates to the size_t overload).
1269 void selectJump(
bool isTrue) { selectJump(
size_t(isTrue ? 0 : 1)); }
1271 void selectJump(
size_t destIndex) {
1272 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1276 template <
typename T,
typename RangeT,
typename Comparator = std::equal_to<T>>
1277 void handleSwitch(
const T &value, RangeT &&cases, Comparator cmp = {}) {
1279 llvm::dbgs() <<
" * Value: " << value <<
"\n"
1281 llvm::interleaveComma(cases, llvm::dbgs());
1282 llvm::dbgs() <<
"\n";
1287 for (
auto it = cases.begin(), e = cases.end(); it != e; ++it)
1288 if (cmp(*it, value))
1289 return selectJump(
size_t((it - cases.begin()) + 1));
1290 selectJump(
size_t(0));
1294 void storeToMemory(
unsigned index,
const void *value) {
1295 memory[index] = value;
1299 template <
typename T>
1300 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1301 storeToMemory(
unsigned index, T value) {
1302 memory[index] = value.getAsOpaquePointer();
1307 template <
typename T>
1308 const void *readFromMemory() {
1309 size_t index = *curCodeIt++;
1314 index < memory.size())
1315 return memory[index];
1318 return uniquedMemory[index - memory.size()];
1320 template <
typename T>
1321 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1322 return reinterpret_cast<T
>(
const_cast<void *
>(readFromMemory<T>()));
1324 template <
typename T>
1325 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1328 return T(T::getFromOpaquePointer(readFromMemory<T>()));
1330 template <
typename T>
1331 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1332 switch (read<PDLValue::Kind>()) {
1333 case PDLValue::Kind::Attribute:
1334 return read<Attribute>();
1335 case PDLValue::Kind::Operation:
1336 return read<Operation *>();
1337 case PDLValue::Kind::Type:
1338 return read<Type>();
1339 case PDLValue::Kind::Value:
1340 return read<Value>();
1341 case PDLValue::Kind::TypeRange:
1342 return read<TypeRange *>();
1343 case PDLValue::Kind::ValueRange:
1344 return read<ValueRange *>();
1346 llvm_unreachable(
"unhandled PDLValue::Kind");
1348 template <
typename T>
1349 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1350 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
1351 "unexpected ByteCode address size");
1352 ByteCodeAddr result;
1353 std::memcpy(&result, curCodeIt,
sizeof(ByteCodeAddr));
1357 template <
typename T>
1358 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1359 return *curCodeIt++;
1361 template <
typename T>
1362 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1368 template <
typename RangeT,
typename T = llvm::detail::ValueOfRange<RangeT>>
1369 void assignRangeToMemory(RangeT &&range,
unsigned memIndex,
1370 unsigned rangeIndex) {
1372 auto assignRange = [&](
auto &allocatedRangeMemory,
auto &rangeMemory) {
1374 if (range.empty()) {
1375 rangeMemory[rangeIndex] = {};
1378 llvm::OwningArrayRef<T> storage(llvm::size(range));
1383 allocatedRangeMemory.emplace_back(std::move(storage));
1384 rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1386 memory[memIndex] = &rangeMemory[rangeIndex];
1390 if constexpr (std::is_same_v<T, Type>) {
1391 return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1392 }
else if constexpr (std::is_same_v<T, Value>) {
1393 return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1395 llvm_unreachable(
"unhandled range type");
1400 const ByteCodeField *curCodeIt;
1409 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1411 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1426 void ByteCodeExecutor::executeApplyConstraint(
PatternRewriter &rewriter) {
1427 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyConstraint:\n");
1428 ByteCodeField fun_idx = read();
1430 readList<PDLValue>(args);
1433 llvm::dbgs() <<
" * Arguments: ";
1434 llvm::interleaveComma(args, llvm::dbgs());
1435 llvm::dbgs() <<
"\n";
1438 ByteCodeField isNegated = read();
1440 llvm::dbgs() <<
" * isNegated: " << isNegated <<
"\n";
1441 llvm::interleaveComma(args, llvm::dbgs());
1444 ByteCodeField numResults = read();
1445 const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1446 ByteCodeRewriteResultList results(numResults);
1447 LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1450 if (succeeded(rewriteResult)) {
1451 llvm::dbgs() <<
" * Constraint succeeded\n";
1452 llvm::dbgs() <<
" * Results: ";
1453 llvm::interleaveComma(constraintResults, llvm::dbgs());
1454 llvm::dbgs() <<
"\n";
1456 llvm::dbgs() <<
" * Constraint failed\n";
1459 assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1460 "native PDL rewrite function succeeded but returned "
1461 "unexpected number of results");
1462 processNativeFunResults(results, numResults, rewriteResult);
1465 selectJump(isNegated != succeeded(rewriteResult));
1468 LogicalResult ByteCodeExecutor::executeApplyRewrite(
PatternRewriter &rewriter) {
1469 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyRewrite:\n");
1470 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1472 readList<PDLValue>(args);
1475 llvm::dbgs() <<
" * Arguments: ";
1476 llvm::interleaveComma(args, llvm::dbgs());
1480 ByteCodeField numResults = read();
1481 ByteCodeRewriteResultList results(numResults);
1482 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1484 assert(results.getResults().size() == numResults &&
1485 "native PDL rewrite function returned unexpected number of results");
1487 processNativeFunResults(results, numResults, rewriteResult);
1489 if (failed(rewriteResult)) {
1490 LLVM_DEBUG(llvm::dbgs() <<
" - Failed");
1496 void ByteCodeExecutor::processNativeFunResults(
1497 ByteCodeRewriteResultList &results,
unsigned numResults,
1498 LogicalResult &rewriteResult) {
1501 for (
unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1506 if (failed(rewriteResult)) {
1507 if (resultKind == PDLValue::Kind::TypeRange ||
1508 resultKind == PDLValue::Kind::ValueRange) {
1515 PDLValue result = results.getResults()[resultIdx];
1516 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << result <<
"\n");
1517 assert(result.getKind() == resultKind &&
1518 "native PDL rewrite function returned an unexpected type of "
1522 if (std::optional<TypeRange> typeRange = result.dyn_cast<
TypeRange>()) {
1523 unsigned rangeIndex = read();
1524 typeRangeMemory[rangeIndex] = *typeRange;
1525 memory[read()] = &typeRangeMemory[rangeIndex];
1526 }
else if (std::optional<ValueRange> valueRange =
1528 unsigned rangeIndex = read();
1529 valueRangeMemory[rangeIndex] = *valueRange;
1530 memory[read()] = &valueRangeMemory[rangeIndex];
1532 memory[read()] = result.getAsOpaquePointer();
1537 for (
auto &it : results.getAllocatedTypeRanges())
1538 allocatedTypeRangeMemory.push_back(std::move(it));
1539 for (
auto &it : results.getAllocatedValueRanges())
1540 allocatedValueRangeMemory.push_back(std::move(it));
1543 void ByteCodeExecutor::executeAreEqual() {
1544 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreEqual:\n");
1545 const void *lhs = read<const void *>();
1546 const void *rhs = read<const void *>();
1548 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n");
1549 selectJump(lhs == rhs);
1552 void ByteCodeExecutor::executeAreRangesEqual() {
1553 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreRangesEqual:\n");
1555 const void *lhs = read<const void *>();
1556 const void *rhs = read<const void *>();
1558 switch (valueKind) {
1559 case PDLValue::Kind::TypeRange: {
1562 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1563 selectJump(*lhsRange == *rhsRange);
1566 case PDLValue::Kind::ValueRange: {
1567 const auto *lhsRange =
reinterpret_cast<const ValueRange *
>(lhs);
1568 const auto *rhsRange =
reinterpret_cast<const ValueRange *
>(rhs);
1569 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1570 selectJump(*lhsRange == *rhsRange);
1574 llvm_unreachable(
"unexpected `AreRangesEqual` value kind");
1578 void ByteCodeExecutor::executeBranch() {
1579 LLVM_DEBUG(llvm::dbgs() <<
"Executing Branch\n");
1580 curCodeIt = &code[read<ByteCodeAddr>()];
1583 void ByteCodeExecutor::executeCheckOperandCount() {
1584 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperandCount:\n");
1586 uint32_t expectedCount = read<uint32_t>();
1587 bool compareAtLeast = read();
1589 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumOperands() <<
"\n"
1590 <<
" * Expected: " << expectedCount <<
"\n"
1591 <<
" * Comparator: "
1592 << (compareAtLeast ?
">=" :
"==") <<
"\n");
1599 void ByteCodeExecutor::executeCheckOperationName() {
1600 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperationName:\n");
1604 LLVM_DEBUG(llvm::dbgs() <<
" * Found: \"" << op->
getName() <<
"\"\n"
1605 <<
" * Expected: \"" << expectedName <<
"\"\n");
1606 selectJump(op->
getName() == expectedName);
1609 void ByteCodeExecutor::executeCheckResultCount() {
1610 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckResultCount:\n");
1612 uint32_t expectedCount = read<uint32_t>();
1613 bool compareAtLeast = read();
1615 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumResults() <<
"\n"
1616 <<
" * Expected: " << expectedCount <<
"\n"
1617 <<
" * Comparator: "
1618 << (compareAtLeast ?
">=" :
"==") <<
"\n");
1625 void ByteCodeExecutor::executeCheckTypes() {
1626 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreEqual:\n");
1629 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1631 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1634 void ByteCodeExecutor::executeContinue() {
1635 ByteCodeField level = read();
1636 LLVM_DEBUG(llvm::dbgs() <<
"Executing Continue\n"
1637 <<
" * Level: " << level <<
"\n");
1642 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1643 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateConstantTypeRange:\n");
1644 unsigned memIndex = read();
1645 unsigned rangeIndex = read();
1646 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1648 LLVM_DEBUG(llvm::dbgs() <<
" * Types: " << typesAttr <<
"\n\n");
1649 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1653 void ByteCodeExecutor::executeCreateOperation(
PatternRewriter &rewriter,
1655 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateOperation:\n");
1657 unsigned memIndex = read();
1659 readList(state.operands);
1660 for (
unsigned i = 0, e = read(); i != e; ++i) {
1661 StringAttr name = read<StringAttr>();
1663 state.addAttribute(name, attr);
1668 unsigned numResults = read();
1670 InferTypeOpInterface::Concept *inferInterface =
1671 state.name.getInterface<InferTypeOpInterface>();
1672 assert(inferInterface &&
1673 "expected operation to provide InferTypeOpInterface");
1676 if (failed(inferInterface->inferReturnTypes(
1677 state.getContext(), state.location, state.operands,
1678 state.attributes.getDictionary(state.getContext()),
1679 state.getRawProperties(), state.regions, state.types)))
1683 for (
unsigned i = 0; i != numResults; ++i) {
1684 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1685 state.types.push_back(read<Type>());
1687 TypeRange *resultTypes = read<TypeRange *>();
1688 state.types.append(resultTypes->begin(), resultTypes->end());
1694 memory[memIndex] = resultOp;
1697 llvm::dbgs() <<
" * Attributes: "
1698 << state.attributes.getDictionary(state.getContext())
1699 <<
"\n * Operands: ";
1700 llvm::interleaveComma(state.operands, llvm::dbgs());
1701 llvm::dbgs() <<
"\n * Result Types: ";
1702 llvm::interleaveComma(state.types, llvm::dbgs());
1703 llvm::dbgs() <<
"\n * Result: " << *resultOp <<
"\n";
1707 template <
typename T>
1708 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1709 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateDynamic" << type <<
"Range:\n");
1710 unsigned memIndex = read();
1711 unsigned rangeIndex = read();
1716 llvm::dbgs() <<
"\n * " << type <<
"s: ";
1717 llvm::interleaveComma(values, llvm::dbgs());
1718 llvm::dbgs() <<
"\n";
1721 assignRangeToMemory(values, memIndex, rangeIndex);
1725 LLVM_DEBUG(llvm::dbgs() <<
"Executing EraseOp:\n");
1728 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
1732 template <
typename T,
typename Range, PDLValue::Kind kind>
1733 void ByteCodeExecutor::executeExtract() {
1734 LLVM_DEBUG(llvm::dbgs() <<
"Executing Extract" <<
kind <<
":\n");
1735 Range *range = read<Range *>();
1736 unsigned index = read<uint32_t>();
1737 unsigned memIndex = read();
1740 memory[memIndex] =
nullptr;
1744 T result = index < range->
size() ? (*range)[index] : T();
1745 LLVM_DEBUG(llvm::dbgs() <<
" * " <<
kind <<
"s(" << range->
size() <<
")\n"
1746 <<
" * Index: " << index <<
"\n"
1747 <<
" * Result: " << result <<
"\n");
1748 storeToMemory(memIndex, result);
1751 void ByteCodeExecutor::executeFinalize() {
1752 LLVM_DEBUG(llvm::dbgs() <<
"Executing Finalize\n");
1755 void ByteCodeExecutor::executeForEach() {
1756 LLVM_DEBUG(llvm::dbgs() <<
"Executing ForEach:\n");
1757 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1758 unsigned rangeIndex = read();
1759 unsigned memIndex = read();
1760 const void *value =
nullptr;
1762 switch (read<PDLValue::Kind>()) {
1763 case PDLValue::Kind::Operation: {
1764 unsigned &index = loopIndex[read()];
1766 assert(index <= array.size() &&
"iterated past the end");
1767 if (index < array.size()) {
1768 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << array[index] <<
"\n");
1769 value = array[index];
1773 LLVM_DEBUG(llvm::dbgs() <<
" * Done\n");
1775 selectJump(
size_t(0));
1779 llvm_unreachable(
"unexpected `ForEach` value kind");
1783 memory[memIndex] = value;
1784 pushCodeIt(prevCodeIt);
1787 read<ByteCodeAddr>();
1790 void ByteCodeExecutor::executeGetAttribute() {
1791 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttribute:\n");
1792 unsigned memIndex = read();
1794 StringAttr attrName = read<StringAttr>();
1797 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1798 <<
" * Attribute: " << attrName <<
"\n"
1799 <<
" * Result: " << attr <<
"\n");
1803 void ByteCodeExecutor::executeGetAttributeType() {
1804 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttributeType:\n");
1805 unsigned memIndex = read();
1808 if (
auto typedAttr = dyn_cast<TypedAttr>(attr))
1809 type = typedAttr.getType();
1811 LLVM_DEBUG(llvm::dbgs() <<
" * Attribute: " << attr <<
"\n"
1812 <<
" * Result: " << type <<
"\n");
1816 void ByteCodeExecutor::executeGetDefiningOp() {
1817 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetDefiningOp:\n");
1818 unsigned memIndex = read();
1820 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1821 Value value = read<Value>();
1824 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1827 if (values && !values->empty()) {
1828 op = values->front().getDefiningOp();
1830 LLVM_DEBUG(llvm::dbgs() <<
" * Values: " << values <<
"\n");
1833 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << op <<
"\n");
1834 memory[memIndex] = op;
1837 void ByteCodeExecutor::executeGetOperand(
unsigned index) {
1839 unsigned memIndex = read();
1843 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1844 <<
" * Index: " << index <<
"\n"
1845 <<
" * Result: " << operand <<
"\n");
1852 template <
template <
typename>
class AttrSizedSegmentsT,
typename RangeT>
1855 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1860 LLVM_DEBUG(llvm::dbgs() <<
" * Getting all values\n");
1864 }
else if (op->
hasTrait<AttrSizedSegmentsT>()) {
1865 LLVM_DEBUG(llvm::dbgs()
1866 <<
" * Extracting values from `" << attrSizedSegments <<
"`\n");
1869 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1873 unsigned startIndex =
1874 std::accumulate(segments.begin(), segments.begin() + index, 0);
1875 values = values.slice(startIndex, *std::next(segments.begin(), index));
1877 LLVM_DEBUG(llvm::dbgs() <<
" * Extracting range[" << startIndex <<
", "
1878 << *std::next(segments.begin(), index) <<
"]\n");
1884 }
else if (values.size() >= index) {
1885 LLVM_DEBUG(llvm::dbgs()
1886 <<
" * Treating values as trailing variadic range\n");
1887 values = values.drop_front(index);
1896 valueRangeMemory[rangeIndex] = values;
1897 return &valueRangeMemory[rangeIndex];
1901 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1904 void ByteCodeExecutor::executeGetOperands() {
1905 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperands:\n");
1906 unsigned index = read<uint32_t>();
1908 ByteCodeField rangeIndex = read();
1910 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1911 op->
getOperands(), op, index, rangeIndex,
"operandSegmentSizes",
1914 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid operand range\n");
1915 memory[read()] = result;
1918 void ByteCodeExecutor::executeGetResult(
unsigned index) {
1920 unsigned memIndex = read();
1924 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1925 <<
" * Index: " << index <<
"\n"
1926 <<
" * Result: " << result <<
"\n");
1930 void ByteCodeExecutor::executeGetResults() {
1931 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResults:\n");
1932 unsigned index = read<uint32_t>();
1934 ByteCodeField rangeIndex = read();
1936 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1937 op->
getResults(), op, index, rangeIndex,
"resultSegmentSizes",
1940 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid result range\n");
1941 memory[read()] = result;
1944 void ByteCodeExecutor::executeGetUsers() {
1945 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetUsers:\n");
1946 unsigned memIndex = read();
1947 unsigned rangeIndex = read();
1948 OwningOpRange &range = opRangeMemory[rangeIndex];
1949 memory[memIndex] = ⦥
1951 range = OwningOpRange();
1952 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1954 Value value = read<Value>();
1957 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1968 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
1969 llvm::interleaveComma(*values, llvm::dbgs());
1970 llvm::dbgs() <<
"\n";
1975 for (
Value value : *values)
1977 range = OwningOpRange(users.size());
1981 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << range.size() <<
" operations\n");
1984 void ByteCodeExecutor::executeGetValueType() {
1985 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueType:\n");
1986 unsigned memIndex = read();
1987 Value value = read<Value>();
1990 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n"
1991 <<
" * Result: " << type <<
"\n");
1995 void ByteCodeExecutor::executeGetValueRangeTypes() {
1996 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueRangeTypes:\n");
1997 unsigned memIndex = read();
1998 unsigned rangeIndex = read();
2001 LLVM_DEBUG(llvm::dbgs() <<
" * Values: <NULL>\n\n");
2002 memory[memIndex] =
nullptr;
2007 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
2008 llvm::interleaveComma(*values, llvm::dbgs());
2009 llvm::dbgs() <<
"\n * Result: ";
2010 llvm::interleaveComma(values->
getType(), llvm::dbgs());
2011 llvm::dbgs() <<
"\n";
2013 typeRangeMemory[rangeIndex] = values->
getType();
2014 memory[memIndex] = &typeRangeMemory[rangeIndex];
2017 void ByteCodeExecutor::executeIsNotNull() {
2018 LLVM_DEBUG(llvm::dbgs() <<
"Executing IsNotNull:\n");
2019 const void *value = read<const void *>();
2021 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
2022 selectJump(value !=
nullptr);
2025 void ByteCodeExecutor::executeRecordMatch(
2028 LLVM_DEBUG(llvm::dbgs() <<
"Executing RecordMatch:\n");
2029 unsigned patternIndex = read();
2031 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
2036 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: Impossible To Match\n");
2045 unsigned numMatchLocs = read();
2047 matchLocs.reserve(numMatchLocs);
2048 for (
unsigned i = 0; i != numMatchLocs; ++i)
2049 matchLocs.push_back(read<Operation *>()->getLoc());
2052 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: " << benefit.
getBenefit() <<
"\n"
2053 <<
" * Location: " << matchLoc <<
"\n");
2054 matches.emplace_back(matchLoc,
patterns[patternIndex], benefit);
2060 unsigned numInputs = read();
2061 match.values.reserve(numInputs);
2062 match.typeRangeValues.reserve(numInputs);
2063 match.valueRangeValues.reserve(numInputs);
2064 for (
unsigned i = 0; i < numInputs; ++i) {
2065 switch (read<PDLValue::Kind>()) {
2066 case PDLValue::Kind::TypeRange:
2067 match.typeRangeValues.push_back(*read<TypeRange *>());
2068 match.values.push_back(&match.typeRangeValues.back());
2070 case PDLValue::Kind::ValueRange:
2071 match.valueRangeValues.push_back(*read<ValueRange *>());
2072 match.values.push_back(&match.valueRangeValues.back());
2075 match.values.push_back(read<const void *>());
2083 LLVM_DEBUG(llvm::dbgs() <<
"Executing ReplaceOp:\n");
2089 llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
2091 llvm::interleaveComma(args, llvm::dbgs());
2092 llvm::dbgs() <<
"\n";
2097 void ByteCodeExecutor::executeSwitchAttribute() {
2098 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchAttribute:\n");
2100 ArrayAttr cases = read<ArrayAttr>();
2101 handleSwitch(value, cases);
2104 void ByteCodeExecutor::executeSwitchOperandCount() {
2105 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperandCount:\n");
2107 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2109 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
2113 void ByteCodeExecutor::executeSwitchOperationName() {
2114 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperationName:\n");
2116 size_t caseCount = read();
2122 const ByteCodeField *prevCodeIt = curCodeIt;
2123 llvm::dbgs() <<
" * Value: " << value <<
"\n"
2125 llvm::interleaveComma(
2126 llvm::map_range(llvm::seq<size_t>(0, caseCount),
2127 [&](
size_t) {
return read<OperationName>(); }),
2129 llvm::dbgs() <<
"\n";
2130 curCodeIt = prevCodeIt;
2134 for (
size_t i = 0; i != caseCount; ++i) {
2135 if (read<OperationName>() == value) {
2136 curCodeIt += (caseCount - i - 1);
2137 return selectJump(i + 1);
2140 selectJump(
size_t(0));
2143 void ByteCodeExecutor::executeSwitchResultCount() {
2144 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchResultCount:\n");
2146 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2148 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
2152 void ByteCodeExecutor::executeSwitchType() {
2153 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchType:\n");
2154 Type value = read<Type>();
2155 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2156 handleSwitch(value, cases);
2159 void ByteCodeExecutor::executeSwitchTypes() {
2160 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchTypes:\n");
2162 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2164 LLVM_DEBUG(llvm::dbgs() <<
"Types: <NULL>\n");
2165 return selectJump(
size_t(0));
2167 handleSwitch(*value, cases, [](ArrayAttr caseValue,
const TypeRange &value) {
2168 return value == caseValue.getAsValueRange<TypeAttr>();
2175 std::optional<Location> mainRewriteLoc) {
2178 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() <<
"\n");
2180 OpCode opCode =
static_cast<OpCode
>(read());
2182 case ApplyConstraint:
2183 executeApplyConstraint(rewriter);
2186 if (failed(executeApplyRewrite(rewriter)))
2192 case AreRangesEqual:
2193 executeAreRangesEqual();
2198 case CheckOperandCount:
2199 executeCheckOperandCount();
2201 case CheckOperationName:
2202 executeCheckOperationName();
2204 case CheckResultCount:
2205 executeCheckResultCount();
2208 executeCheckTypes();
2213 case CreateConstantTypeRange:
2214 executeCreateConstantTypeRange();
2216 case CreateOperation:
2217 executeCreateOperation(rewriter, *mainRewriteLoc);
2219 case CreateDynamicTypeRange:
2220 executeDynamicCreateRange<Type>(
"Type");
2222 case CreateDynamicValueRange:
2223 executeDynamicCreateRange<Value>(
"Value");
2226 executeEraseOp(rewriter);
2229 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2232 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2235 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2239 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2245 executeGetAttribute();
2247 case GetAttributeType:
2248 executeGetAttributeType();
2251 executeGetDefiningOp();
2257 unsigned index = opCode - GetOperand0;
2258 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperand" << index <<
":\n");
2259 executeGetOperand(index);
2263 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperandN:\n");
2264 executeGetOperand(read<uint32_t>());
2267 executeGetOperands();
2273 unsigned index = opCode - GetResult0;
2274 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResult" << index <<
":\n");
2275 executeGetResult(index);
2279 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResultN:\n");
2280 executeGetResult(read<uint32_t>());
2283 executeGetResults();
2289 executeGetValueType();
2291 case GetValueRangeTypes:
2292 executeGetValueRangeTypes();
2299 "expected matches to be provided when executing the matcher");
2300 executeRecordMatch(rewriter, *matches);
2303 executeReplaceOp(rewriter);
2305 case SwitchAttribute:
2306 executeSwitchAttribute();
2308 case SwitchOperandCount:
2309 executeSwitchOperandCount();
2311 case SwitchOperationName:
2312 executeSwitchOperationName();
2314 case SwitchResultCount:
2315 executeSwitchResultCount();
2318 executeSwitchType();
2321 executeSwitchTypes();
2324 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2332 state.memory[0] = op;
2335 ByteCodeExecutor executor(
2336 matcherByteCode.data(), state.memory, state.opRangeMemory,
2337 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2338 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2339 uniquedData, matcherByteCode, state.currentPatternBenefits,
patterns,
2340 constraintFunctions, rewriteFunctions);
2341 LogicalResult executeResult = executor.execute(rewriter, &matches);
2342 (void)executeResult;
2343 assert(succeeded(executeResult) &&
"unexpected matcher execution failure");
2346 std::stable_sort(matches.begin(), matches.end(),
2347 [](
const MatchResult &lhs,
const MatchResult &rhs) {
2348 return lhs.benefit > rhs.benefit;
2353 const MatchResult &match,
2355 auto *configSet =
match.pattern->getConfigSet();
2357 configSet->notifyRewriteBegin(rewriter);
2363 ByteCodeExecutor executor(
2364 &rewriterByteCode[
match.pattern->getRewriterAddr()], state.memory,
2365 state.opRangeMemory, state.typeRangeMemory,
2366 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2367 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2368 rewriterByteCode, state.currentPatternBenefits,
patterns,
2369 constraintFunctions, rewriteFunctions);
2370 LogicalResult result =
2371 executor.execute(rewriter,
nullptr,
match.location);
2374 configSet->notifyRewriteEnd(rewriter);
2383 LLVM_DEBUG(llvm::dbgs() <<
" and rollback is not supported - aborting");
2384 llvm::report_fatal_error(
2385 "Native PDL Rewrite failed, but the pattern "
2386 "rewriter doesn't support recovery. Failable pattern rewrites should "
2387 "not be used with pattern rewriters that do not support them.");
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
union mlir::linalg::@1193::ArityGroupAndKind::Kind kind
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
This class represents liveness information on block level.
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Represents an analysis for computing liveness information from a given top-level operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class implements the result iterators for the Operation class.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
auto walk(WalkFns &&...walkFns)
Walk this type and all attibutes/types nested within using the provided walk functions.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_iterator user_begin() const
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
user_iterator user_end() const
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...