19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 27 #define DEBUG_TYPE "pdl-bytecode" 39 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
41 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
63 currentPatternBenefits[patternIndex] = benefit;
70 allocatedTypeRangeMemory.clear();
71 allocatedValueRangeMemory.clear();
// Forward declarations of the helper types that are defined further down in
// this file (ByteCodeWriter and ByteCodeLiveRange).
struct ByteCodeLiveRange;
struct ByteCodeWriter;

/// Detection alias naming the result type of `T::getAsOpaquePointer()`.
/// The alias is ill-formed (and therefore SFINAE-friendly) when `T` has no
/// such member. The trailing `Args` pack is unused but kept so that the
/// alias keeps its original template signature.
template <class T, class... Args>
using has_pointer_traits =
    decltype(std::declval<T>().getAsOpaquePointer());
187 Generator(
MLIRContext *ctx, std::vector<const void *> &uniquedData,
196 llvm::StringMap<PDLConstraintFunction> &constraintFns,
197 llvm::StringMap<PDLRewriteFunction> &rewriteFns)
198 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
199 rewriterByteCode(rewriterByteCode), patterns(patterns),
200 maxValueMemoryIndex(maxValueMemoryIndex),
201 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
202 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
203 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
204 maxLoopLevel(maxLoopLevel) {
206 constraintToMemIndex.try_emplace(it.value().first(), it.index());
208 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
212 void generate(ModuleOp module);
216 assert(valueToMemIndex.count(value) &&
217 "expected memory index to be assigned");
218 return valueToMemIndex[
value];
223 assert(valueToRangeIndex.count(value) &&
224 "expected range index to be assigned");
225 return valueToRangeIndex[
value];
230 template <
typename T>
233 const void *opaqueVal = val.getAsOpaquePointer();
236 auto it = uniquedDataToMemIndex.try_emplace(
237 opaqueVal, maxValueMemoryIndex + uniquedData.size());
239 uniquedData.push_back(opaqueVal);
240 return it.first->second;
246 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
247 ModuleOp rewriterModule);
250 void generate(
Region *region, ByteCodeWriter &writer);
251 void generate(
Operation *op, ByteCodeWriter &writer);
252 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
253 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
254 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
255 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
256 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
257 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
258 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
259 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
260 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
261 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
262 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
263 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
298 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
302 llvm::StringMap<ByteCodeField> constraintToMemIndex;
306 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
322 std::vector<const void *> &uniquedData;
334 struct ByteCodeWriter {
336 : bytecode(bytecode), generator(generator) {}
339 void append(
ByteCodeField field) { bytecode.push_back(field); }
340 void append(
OpCode opCode) { bytecode.push_back(opCode); }
345 "unexpected ByteCode address size");
349 bytecode.append({fieldParts[0], fieldParts[1]});
354 void append(
Block *successor) {
357 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
364 for (
Block *successor : successors)
370 bytecode.push_back(values.size());
371 for (
Value value : values)
372 appendPDLValue(value);
376 void appendPDLValue(
Value value) {
377 appendPDLValueKind(value);
382 void appendPDLValueKind(
Value value) { appendPDLValueKind(value.
getType()); }
385 void appendPDLValueKind(
Type type) {
388 .Case<pdl::AttributeType>(
390 .Case<pdl::OperationType>(
392 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
393 if (rangeTy.getElementType().isa<pdl::TypeType>())
399 bytecode.push_back(static_cast<ByteCodeField>(kind));
404 template <
typename T>
408 bytecode.push_back(
generator.getMemIndex(value));
412 template <
typename T,
typename IteratorT = llvm::detail::IterOfRange<T>>
415 bytecode.push_back(llvm::size(range));
416 for (
auto it : range)
421 template <
typename FieldTy,
typename Field2Ty,
typename... FieldTys>
422 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
424 append(field2, fields...);
428 template <
typename T>
430 appendInline(T value) {
431 constexpr
size_t numParts =
sizeof(
const void *) /
sizeof(
ByteCodeField);
432 const void *pointer = value.getAsOpaquePointer();
434 std::memcpy(fieldParts, &pointer,
sizeof(
const void *));
435 bytecode.append(fieldParts, fieldParts + numParts);
450 struct ByteCodeLiveRange {
451 using Set = llvm::IntervalMap<uint64_t, char, 16>;
452 using Allocator = Set::Allocator;
454 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
457 void unionWith(
const ByteCodeLiveRange &rhs) {
458 for (
auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
460 liveness->insert(it.start(), it.stop(), 0);
464 bool overlaps(
const ByteCodeLiveRange &rhs)
const {
465 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
475 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
488 void Generator::generate(ModuleOp module) {
489 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
490 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
491 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
492 pdl_interp::PDLInterpDialect::getRewriterModuleName());
493 assert(matcherFunc && rewriterModule &&
"invalid PDL Interpreter module");
497 allocateMemoryIndices(matcherFunc, rewriterModule);
500 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *
this);
501 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
502 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
503 for (
Operation &op : rewriterFunc.getOps())
504 generate(&op, rewriterByteCodeWriter);
506 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
507 "unexpected branches in rewriter function");
510 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *
this);
511 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
514 for (
auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
516 for (
unsigned offsetToFix : it.second)
517 std::memcpy(&matcherByteCode[offsetToFix], &addr,
sizeof(
ByteCodeAddr));
521 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
522 ModuleOp rewriterModule) {
525 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
526 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
527 auto processRewriterValue = [&](
Value val) {
528 valueToMemIndex.try_emplace(val, index++);
529 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
530 Type elementTy = rangeType.getElementType();
531 if (elementTy.
isa<pdl::TypeType>())
532 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
533 else if (elementTy.
isa<pdl::ValueType>())
534 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
539 processRewriterValue(arg);
540 rewriterFunc.getBody().walk([&](
Operation *op) {
542 processRewriterValue(result);
544 if (index > maxValueMemoryIndex)
545 maxValueMemoryIndex = index;
546 if (typeRangeIndex > maxTypeRangeMemoryIndex)
547 maxTypeRangeMemoryIndex = typeRangeIndex;
548 if (valueRangeIndex > maxValueRangeMemoryIndex)
549 maxValueRangeMemoryIndex = valueRangeIndex;
565 llvm::unique_function<void(Operation *)>
walk = [&](
Operation *op) {
566 opToFirstIndex.try_emplace(op, index++);
568 for (
Block &block : region.getBlocks())
571 opToLastIndex.try_emplace(op, index++);
576 ByteCodeLiveRange::Allocator allocator;
581 valueToMemIndex[rootOpArg] = 0;
584 Liveness matcherLiveness(matcherFunc);
585 matcherFunc->walk([&](
Block *block) {
587 assert(info &&
"expected liveness info for block");
591 if (value == rootOpArg)
595 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
596 defRangeIt->second.liveness->insert(
597 opToFirstIndex[firstUseOrDef],
603 Type eleType = rangeTy.getElementType();
604 if (eleType.
isa<pdl::OperationType>())
605 defRangeIt->second.opRangeIndex = 0;
606 else if (eleType.
isa<pdl::TypeType>())
607 defRangeIt->second.typeRangeIndex = 0;
608 else if (eleType.
isa<pdl::ValueType>())
609 defRangeIt->second.valueRangeIndex = 0;
614 for (
Value liveIn : info->
in()) {
619 if (liveIn.getParentRegion() == block->
getParent())
636 std::vector<ByteCodeLiveRange> allocatedIndices;
643 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
645 for (
auto &defIt : valueDefRanges) {
647 ByteCodeLiveRange &defRange = defIt.second;
650 for (
const auto &existingIndexIt :
llvm::enumerate(allocatedIndices)) {
651 ByteCodeLiveRange &existingRange = existingIndexIt.value();
652 if (!defRange.overlaps(existingRange)) {
653 existingRange.unionWith(defRange);
654 memIndex = existingIndexIt.index() + 1;
656 if (defRange.opRangeIndex) {
657 if (!existingRange.opRangeIndex)
658 existingRange.opRangeIndex = numOpRanges++;
659 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
660 }
else if (defRange.typeRangeIndex) {
661 if (!existingRange.typeRangeIndex)
662 existingRange.typeRangeIndex = numTypeRanges++;
663 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
664 }
else if (defRange.valueRangeIndex) {
665 if (!existingRange.valueRangeIndex)
666 existingRange.valueRangeIndex = numValueRanges++;
667 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
675 allocatedIndices.emplace_back(allocator);
676 ByteCodeLiveRange &newRange = allocatedIndices.back();
677 newRange.unionWith(defRange);
680 if (defRange.opRangeIndex) {
681 newRange.opRangeIndex = numOpRanges;
682 valueToRangeIndex[defIt.first] = numOpRanges++;
683 }
else if (defRange.typeRangeIndex) {
684 newRange.typeRangeIndex = numTypeRanges;
685 valueToRangeIndex[defIt.first] = numTypeRanges++;
686 }
else if (defRange.valueRangeIndex) {
687 newRange.valueRangeIndex = numValueRanges;
688 valueToRangeIndex[defIt.first] = numValueRanges++;
691 memIndex = allocatedIndices.size();
698 llvm::dbgs() <<
"Allocated " << allocatedIndices.size() <<
" indices " 699 <<
"(down from initial " << valueDefRanges.size() <<
").\n";
702 "Ran out of memory for allocated indices");
705 if (numIndices > maxValueMemoryIndex)
706 maxValueMemoryIndex = numIndices;
707 if (numOpRanges > maxOpRangeMemoryIndex)
708 maxOpRangeMemoryIndex = numOpRanges;
709 if (numTypeRanges > maxTypeRangeMemoryIndex)
710 maxTypeRangeMemoryIndex = numTypeRanges;
711 if (numValueRanges > maxValueRangeMemoryIndex)
712 maxValueRangeMemoryIndex = numValueRanges;
715 void Generator::generate(
Region *region, ByteCodeWriter &writer) {
716 llvm::ReversePostOrderTraversal<Region *> rpot(region);
717 for (
Block *block : rpot) {
719 blockToAddr.try_emplace(block, matcherByteCode.size());
721 generate(&op, writer);
725 void Generator::generate(
Operation *op, ByteCodeWriter &writer) {
729 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
730 writer.appendInline(op->
getLoc());
733 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
734 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
735 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
736 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
737 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
738 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
739 pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
740 pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
741 pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
742 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
743 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
744 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
745 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
746 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
747 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
748 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
749 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
750 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
751 pdl_interp::SwitchResultCountOp>(
752 [&](
auto interpOp) { this->generate(interpOp, writer); })
754 llvm_unreachable(
"unknown `pdl_interp` operation");
758 void Generator::generate(pdl_interp::ApplyConstraintOp op,
759 ByteCodeWriter &writer) {
760 assert(constraintToMemIndex.count(op.getName()) &&
761 "expected index for constraint function");
762 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
763 writer.appendPDLValueList(op.getArgs());
764 writer.append(op.getSuccessors());
766 void Generator::generate(pdl_interp::ApplyRewriteOp op,
767 ByteCodeWriter &writer) {
768 assert(externalRewriterToMemIndex.count(op.getName()) &&
769 "expected index for rewrite function");
770 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
771 writer.appendPDLValueList(op.getArgs());
775 for (
Value result : results) {
779 writer.appendPDLValueKind(result);
783 if (result.getType().isa<pdl::RangeType>())
784 writer.append(getRangeStorageIndex(result));
785 writer.append(result);
788 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
789 Value lhs = op.getLhs();
791 writer.append(OpCode::AreRangesEqual);
792 writer.appendPDLValueKind(lhs);
793 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
797 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
799 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
802 void Generator::generate(pdl_interp::CheckAttributeOp op,
803 ByteCodeWriter &writer) {
804 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
807 void Generator::generate(pdl_interp::CheckOperandCountOp op,
808 ByteCodeWriter &writer) {
809 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
813 void Generator::generate(pdl_interp::CheckOperationNameOp op,
814 ByteCodeWriter &writer) {
815 writer.append(OpCode::CheckOperationName, op.getInputOp(),
818 void Generator::generate(pdl_interp::CheckResultCountOp op,
819 ByteCodeWriter &writer) {
820 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
824 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
825 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
828 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
829 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
832 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
833 assert(curLoopLevel > 0 &&
"encountered pdl_interp.continue at top level");
834 writer.append(OpCode::Continue,
ByteCodeField(curLoopLevel - 1));
836 void Generator::generate(pdl_interp::CreateAttributeOp op,
837 ByteCodeWriter &writer) {
839 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
841 void Generator::generate(pdl_interp::CreateOperationOp op,
842 ByteCodeWriter &writer) {
843 writer.append(OpCode::CreateOperation, op.getResultOp(),
845 writer.appendPDLValueList(op.getInputOperands());
849 writer.append(static_cast<ByteCodeField>(attributes.size()));
850 for (
auto it : llvm::zip(op.getInputAttributeNames(), attributes))
851 writer.append(std::get<0>(it), std::get<1>(it));
855 if (op.getInferredResultTypes())
858 writer.appendPDLValueList(op.getInputResultTypes());
860 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
862 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
864 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
865 writer.append(OpCode::CreateTypes, op.getResult(),
866 getRangeStorageIndex(op.getResult()), op.getValue());
868 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
869 writer.append(OpCode::EraseOp, op.getInputOp());
871 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
874 .Case([](pdl::OperationType) {
return OpCode::ExtractOp; })
875 .Case([](pdl::ValueType) {
return OpCode::ExtractValue; })
876 .Case([](pdl::TypeType) {
return OpCode::ExtractType; })
878 llvm_unreachable(
"unsupported element type");
880 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
882 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
883 writer.append(OpCode::Finalize);
885 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
887 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
888 writer.appendPDLValueKind(arg.
getType());
889 writer.append(curLoopLevel, op.getSuccessor());
891 if (curLoopLevel > maxLoopLevel)
892 maxLoopLevel = curLoopLevel;
893 generate(&op.getRegion(), writer);
896 void Generator::generate(pdl_interp::GetAttributeOp op,
897 ByteCodeWriter &writer) {
898 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
901 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
902 ByteCodeWriter &writer) {
903 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
905 void Generator::generate(pdl_interp::GetDefiningOpOp op,
906 ByteCodeWriter &writer) {
907 writer.append(OpCode::GetDefiningOp, op.getInputOp());
908 writer.appendPDLValue(op.getValue());
910 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
911 uint32_t index = op.getIndex();
913 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
915 writer.append(OpCode::GetOperandN, index);
916 writer.append(op.getInputOp(), op.getValue());
918 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
919 Value result = op.getValue();
921 writer.append(OpCode::GetOperands,
925 writer.append(getRangeStorageIndex(result));
928 writer.append(result);
930 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
931 uint32_t index = op.getIndex();
933 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
935 writer.append(OpCode::GetResultN, index);
936 writer.append(op.getInputOp(), op.getValue());
938 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
939 Value result = op.getValue();
941 writer.append(OpCode::GetResults,
945 writer.append(getRangeStorageIndex(result));
948 writer.append(result);
950 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
951 Value operations = op.getOperations();
953 writer.append(OpCode::GetUsers, operations, rangeIndex);
954 writer.appendPDLValue(op.getValue());
956 void Generator::generate(pdl_interp::GetValueTypeOp op,
957 ByteCodeWriter &writer) {
958 if (op.getType().isa<pdl::RangeType>()) {
959 Value result = op.getResult();
960 writer.append(OpCode::GetValueRangeTypes, result,
961 getRangeStorageIndex(result), op.getValue());
963 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
966 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
967 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
969 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
972 op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
973 writer.append(OpCode::RecordMatch, patternIndex,
975 writer.appendPDLValueList(op.getInputs());
977 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
978 writer.append(OpCode::ReplaceOp, op.getInputOp());
979 writer.appendPDLValueList(op.getReplValues());
981 void Generator::generate(pdl_interp::SwitchAttributeOp op,
982 ByteCodeWriter &writer) {
983 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
984 op.getCaseValuesAttr(), op.getSuccessors());
986 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
987 ByteCodeWriter &writer) {
988 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
989 op.getCaseValuesAttr(), op.getSuccessors());
991 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
992 ByteCodeWriter &writer) {
993 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](
Attribute attr) {
994 return OperationName(attr.cast<StringAttr>().getValue(), ctx);
996 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
999 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1000 ByteCodeWriter &writer) {
1001 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1002 op.getCaseValuesAttr(), op.getSuccessors());
1004 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1005 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1006 op.getSuccessors());
1008 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1009 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1010 op.getSuccessors());
1018 llvm::StringMap<PDLConstraintFunction> constraintFns,
1019 llvm::StringMap<PDLRewriteFunction> rewriteFns) {
1020 Generator
generator(module.getContext(), uniquedData, matcherByteCode,
1021 rewriterByteCode, patterns, maxValueMemoryIndex,
1022 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1023 maxLoopLevel, constraintFns, rewriteFns);
1027 for (
auto &it : constraintFns)
1028 constraintFunctions.push_back(std::move(it.second));
1029 for (
auto &it : rewriteFns)
1030 rewriteFunctions.push_back(std::move(it.second));
1036 state.memory.resize(maxValueMemoryIndex,
nullptr);
1037 state.opRangeMemory.resize(maxOpRangeCount);
1038 state.typeRangeMemory.resize(maxTypeRangeCount,
TypeRange());
1039 state.valueRangeMemory.resize(maxValueRangeCount,
ValueRange());
1040 state.loopIndex.resize(maxLoopLevel, 0);
1041 state.currentPatternBenefits.reserve(patterns.size());
1043 state.currentPatternBenefits.push_back(pattern.getBenefit());
1051 class ByteCodeExecutor {
1057 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1059 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1066 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1067 typeRangeMemory(typeRangeMemory),
1068 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1069 valueRangeMemory(valueRangeMemory),
1070 allocatedValueRangeMemory(allocatedValueRangeMemory),
1071 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1072 currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1073 constraintFunctions(constraintFunctions),
1074 rewriteFunctions(rewriteFunctions) {}
1087 void executeAreEqual();
1088 void executeAreRangesEqual();
1089 void executeBranch();
1090 void executeCheckOperandCount();
1091 void executeCheckOperationName();
1092 void executeCheckResultCount();
1093 void executeCheckTypes();
1094 void executeContinue();
1097 void executeCreateTypes();
1099 template <
typename T,
typename Range, PDLValue::Kind kind>
1100 void executeExtract();
1101 void executeFinalize();
1102 void executeForEach();
1103 void executeGetAttribute();
1104 void executeGetAttributeType();
1105 void executeGetDefiningOp();
1106 void executeGetOperand(
unsigned index);
1107 void executeGetOperands();
1108 void executeGetResult(
unsigned index);
1109 void executeGetResults();
1110 void executeGetUsers();
1111 void executeGetValueType();
1112 void executeGetValueRangeTypes();
1113 void executeIsNotNull();
1117 void executeSwitchAttribute();
1118 void executeSwitchOperandCount();
1119 void executeSwitchOperationName();
1120 void executeSwitchResultCount();
1121 void executeSwitchType();
1122 void executeSwitchTypes();
1125 void pushCodeIt(
const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1129 assert(!resumeCodeIt.empty() &&
"attempt to pop code off empty stack");
1130 curCodeIt = resumeCodeIt.back();
1131 resumeCodeIt.pop_back();
1138 return curCodeIt - 1 -
sizeof(
const void *) /
sizeof(
ByteCodeField);
1142 return curCodeIt - 1;
1148 template <
typename T = ByteCodeField>
1149 T read(
size_t skipN = 0) {
1151 return readImpl<T>();
1153 ByteCodeField read(
size_t skipN = 0) {
return read<ByteCodeField>(skipN); }
1156 template <
typename ValueT,
typename T>
1159 for (
unsigned i = 0, e = read(); i != e; ++i)
1160 list.push_back(read<ValueT>());
1166 for (
unsigned i = 0, e = read(); i != e; ++i) {
1168 list.push_back(read<Value>());
1171 list.append(values->begin(), values->end());
1177 template <
typename T>
1180 const void *pointer;
1181 std::memcpy(&pointer, curCodeIt,
sizeof(
const void *));
1183 return T::getFromOpaquePointer(pointer);
/// Boolean branch: jump to destination index 0 when `isTrue` holds, and to
/// destination index 1 otherwise.
void selectJump(bool isTrue) {
  selectJump(static_cast<size_t>(!isTrue));
}
1189 void selectJump(
size_t destIndex) {
1190 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1194 template <
typename T,
typename RangeT,
typename Comparator = std::equal_to<T>>
1195 void handleSwitch(
const T &value, RangeT &&cases, Comparator cmp = {}) {
1197 llvm::dbgs() <<
" * Value: " << value <<
"\n" 1199 llvm::interleaveComma(cases, llvm::dbgs());
1200 llvm::dbgs() <<
"\n";
1205 for (
auto it = cases.begin(), e = cases.end(); it != e; ++it)
1206 if (cmp(*it, value))
1207 return selectJump(
size_t((it - cases.begin()) + 1));
1208 selectJump(
size_t(0));
1212 void storeToMemory(
unsigned index,
const void *value) {
1213 memory[index] =
value;
1217 template <
typename T>
1219 storeToMemory(
unsigned index, T value) {
1220 memory[index] = value.getAsOpaquePointer();
1225 template <
typename T>
1226 const void *readFromMemory() {
1227 size_t index = *curCodeIt++;
1232 index < memory.size())
1233 return memory[index];
1236 return uniquedMemory[index - memory.size()];
1238 template <
typename T>
1240 return reinterpret_cast<T
>(
const_cast<void *
>(readFromMemory<T>()));
1242 template <
typename T>
1246 return T(T::getFromOpaquePointer(readFromMemory<T>()));
1248 template <
typename T>
1250 switch (read<PDLValue::Kind>()) {
1252 return read<Attribute>();
1254 return read<Operation *>();
1256 return read<Type>();
1258 return read<Value>();
1260 return read<TypeRange *>();
1262 return read<ValueRange *>();
1264 llvm_unreachable(
"unhandled PDLValue::Kind");
1266 template <
typename T>
1269 "unexpected ByteCode address size");
1275 template <
typename T>
1277 return *curCodeIt++;
1279 template <
typename T>
1294 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1296 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1315 ByteCodeRewriteResultList(
unsigned maxNumResults)
1323 return allocatedTypeRanges;
1328 return allocatedValueRanges;
1333 void ByteCodeExecutor::executeApplyConstraint(
PatternRewriter &rewriter) {
1334 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyConstraint:\n");
1337 readList<PDLValue>(args);
1340 llvm::dbgs() <<
" * Arguments: ";
1341 llvm::interleaveComma(args, llvm::dbgs());
1345 selectJump(
succeeded(constraintFn(rewriter, args)));
1348 void ByteCodeExecutor::executeApplyRewrite(
PatternRewriter &rewriter) {
1349 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyRewrite:\n");
1352 readList<PDLValue>(args);
1355 llvm::dbgs() <<
" * Arguments: ";
1356 llvm::interleaveComma(args, llvm::dbgs());
1361 ByteCodeRewriteResultList results(numResults);
1362 rewriteFn(rewriter, results, args);
1364 assert(results.getResults().size() == numResults &&
1365 "native PDL rewrite function returned unexpected number of results");
1368 for (
PDLValue &result : results.getResults()) {
1369 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << result <<
"\n");
1373 assert(result.getKind() == read<PDLValue::Kind>() &&
1374 "native PDL rewrite function returned an unexpected type of result");
1380 unsigned rangeIndex = read();
1381 typeRangeMemory[rangeIndex] = *typeRange;
1382 memory[read()] = &typeRangeMemory[rangeIndex];
1385 unsigned rangeIndex = read();
1386 valueRangeMemory[rangeIndex] = *valueRange;
1387 memory[read()] = &valueRangeMemory[rangeIndex];
1394 for (
auto &it : results.getAllocatedTypeRanges())
1395 allocatedTypeRangeMemory.push_back(std::move(it));
1396 for (
auto &it : results.getAllocatedValueRanges())
1397 allocatedValueRangeMemory.push_back(std::move(it));
1400 void ByteCodeExecutor::executeAreEqual() {
1401 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreEqual:\n");
1402 const void *lhs = read<const void *>();
1403 const void *rhs = read<const void *>();
1405 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n");
1406 selectJump(lhs == rhs);
1409 void ByteCodeExecutor::executeAreRangesEqual() {
1410 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreRangesEqual:\n");
1412 const void *lhs = read<const void *>();
1413 const void *rhs = read<const void *>();
1415 switch (valueKind) {
1419 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1420 selectJump(*lhsRange == *rhsRange);
1424 const auto *lhsRange =
reinterpret_cast<const ValueRange *
>(lhs);
1425 const auto *rhsRange =
reinterpret_cast<const ValueRange *
>(rhs);
1426 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1427 selectJump(*lhsRange == *rhsRange);
1431 llvm_unreachable(
"unexpected `AreRangesEqual` value kind");
1435 void ByteCodeExecutor::executeBranch() {
1436 LLVM_DEBUG(llvm::dbgs() <<
"Executing Branch\n");
1437 curCodeIt = &code[read<ByteCodeAddr>()];
1440 void ByteCodeExecutor::executeCheckOperandCount() {
1441 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperandCount:\n");
1443 uint32_t expectedCount = read<uint32_t>();
1444 bool compareAtLeast = read();
1446 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumOperands() <<
"\n" 1447 <<
" * Expected: " << expectedCount <<
"\n" 1448 <<
" * Comparator: " 1449 << (compareAtLeast ?
">=" :
"==") <<
"\n");
1456 void ByteCodeExecutor::executeCheckOperationName() {
1457 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperationName:\n");
1461 LLVM_DEBUG(llvm::dbgs() <<
" * Found: \"" << op->
getName() <<
"\"\n" 1462 <<
" * Expected: \"" << expectedName <<
"\"\n");
1463 selectJump(op->
getName() == expectedName);
1466 void ByteCodeExecutor::executeCheckResultCount() {
1467 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckResultCount:\n");
1469 uint32_t expectedCount = read<uint32_t>();
1470 bool compareAtLeast = read();
1472 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumResults() <<
"\n" 1473 <<
" * Expected: " << expectedCount <<
"\n" 1474 <<
" * Comparator: " 1475 << (compareAtLeast ?
">=" :
"==") <<
"\n");
1482 void ByteCodeExecutor::executeCheckTypes() {
1483 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreEqual:\n");
1486 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1488 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1491 void ByteCodeExecutor::executeContinue() {
1493 LLVM_DEBUG(llvm::dbgs() <<
"Executing Continue\n" 1494 <<
" * Level: " << level <<
"\n");
1499 void ByteCodeExecutor::executeCreateTypes() {
1500 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateTypes:\n");
1501 unsigned memIndex = read();
1502 unsigned rangeIndex = read();
1503 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1505 LLVM_DEBUG(llvm::dbgs() <<
" * Types: " << typesAttr <<
"\n\n");
1508 llvm::OwningArrayRef<Type> storage(typesAttr.size());
1509 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1510 allocatedTypeRangeMemory.emplace_back(std::move(storage));
1514 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1515 memory[memIndex] = &typeRangeMemory[rangeIndex];
1518 void ByteCodeExecutor::executeCreateOperation(
PatternRewriter &rewriter,
1520 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateOperation:\n");
1522 unsigned memIndex = read();
1525 for (
unsigned i = 0, e = read(); i != e; ++i) {
1526 StringAttr name = read<StringAttr>();
1533 unsigned numResults = read();
1535 InferTypeOpInterface::Concept *inferInterface =
1537 assert(inferInterface &&
1538 "expected operation to provide InferTypeOpInterface");
1541 if (
failed(inferInterface->inferReturnTypes(
1548 for (
unsigned i = 0; i != numResults; ++i) {
1550 state.
types.push_back(read<Type>());
1552 TypeRange *resultTypes = read<TypeRange *>();
1553 state.
types.append(resultTypes->begin(), resultTypes->end());
1559 memory[memIndex] = resultOp;
1562 llvm::dbgs() <<
" * Attributes: " 1564 <<
"\n * Operands: ";
1565 llvm::interleaveComma(state.
operands, llvm::dbgs());
1566 llvm::dbgs() <<
"\n * Result Types: ";
1567 llvm::interleaveComma(state.
types, llvm::dbgs());
1568 llvm::dbgs() <<
"\n * Result: " << *resultOp <<
"\n";
1573 LLVM_DEBUG(llvm::dbgs() <<
"Executing EraseOp:\n");
1576 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
1580 template <
typename T,
typename Range, PDLValue::Kind kind>
1581 void ByteCodeExecutor::executeExtract() {
1582 LLVM_DEBUG(llvm::dbgs() <<
"Executing Extract" << kind <<
":\n");
1583 Range *range = read<Range *>();
1584 unsigned index = read<uint32_t>();
1585 unsigned memIndex = read();
1588 memory[memIndex] =
nullptr;
1592 T result = index < range->
size() ? (*range)[index] : T();
1593 LLVM_DEBUG(llvm::dbgs() <<
" * " << kind <<
"s(" << range->
size() <<
")\n" 1594 <<
" * Index: " << index <<
"\n" 1595 <<
" * Result: " << result <<
"\n");
1596 storeToMemory(memIndex, result);
1599 void ByteCodeExecutor::executeFinalize() {
1600 LLVM_DEBUG(llvm::dbgs() <<
"Executing Finalize\n");
1603 void ByteCodeExecutor::executeForEach() {
1604 LLVM_DEBUG(llvm::dbgs() <<
"Executing ForEach:\n");
1606 unsigned rangeIndex = read();
1607 unsigned memIndex = read();
1608 const void *value =
nullptr;
1610 switch (read<PDLValue::Kind>()) {
1612 unsigned &index = loopIndex[read()];
1614 assert(index <= array.size() &&
"iterated past the end");
1615 if (index < array.size()) {
1616 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << array[index] <<
"\n");
1617 value = array[index];
1621 LLVM_DEBUG(llvm::dbgs() <<
" * Done\n");
1623 selectJump(
size_t(0));
1627 llvm_unreachable(
"unexpected `ForEach` value kind");
1631 memory[memIndex] =
value;
1632 pushCodeIt(prevCodeIt);
1635 read<ByteCodeAddr>();
1638 void ByteCodeExecutor::executeGetAttribute() {
1639 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttribute:\n");
1640 unsigned memIndex = read();
1642 StringAttr attrName = read<StringAttr>();
1645 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n" 1646 <<
" * Attribute: " << attrName <<
"\n" 1647 <<
" * Result: " << attr <<
"\n");
1648 memory[memIndex] = attr.getAsOpaquePointer();
1651 void ByteCodeExecutor::executeGetAttributeType() {
1652 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttributeType:\n");
1653 unsigned memIndex = read();
1656 if (
auto typedAttr = attr.
dyn_cast<TypedAttr>())
1657 type = typedAttr.getType();
1659 LLVM_DEBUG(llvm::dbgs() <<
" * Attribute: " << attr <<
"\n" 1660 <<
" * Result: " << type <<
"\n");
1661 memory[memIndex] = type.getAsOpaquePointer();
1664 void ByteCodeExecutor::executeGetDefiningOp() {
1665 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetDefiningOp:\n");
1666 unsigned memIndex = read();
1669 Value value = read<Value>();
1672 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1675 if (values && !values->empty()) {
1676 op = values->front().getDefiningOp();
1678 LLVM_DEBUG(llvm::dbgs() <<
" * Values: " << values <<
"\n");
1681 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << op <<
"\n");
1682 memory[memIndex] = op;
1685 void ByteCodeExecutor::executeGetOperand(
unsigned index) {
1687 unsigned memIndex = read();
1691 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n" 1692 <<
" * Index: " << index <<
"\n" 1693 <<
" * Result: " << operand <<
"\n");
1700 template <
template <
typename>
class AttrSizedSegmentsT,
typename RangeT>
1708 LLVM_DEBUG(llvm::dbgs() <<
" * Getting all values\n");
1712 }
else if (op->
hasTrait<AttrSizedSegmentsT>()) {
1713 LLVM_DEBUG(llvm::dbgs()
1714 <<
" * Extracting values from `" << attrSizedSegments <<
"`\n");
1717 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1721 unsigned startIndex =
1722 std::accumulate(segments.begin(), segments.begin() + index, 0);
1723 values = values.slice(startIndex, *std::next(segments.begin(), index));
1725 LLVM_DEBUG(llvm::dbgs() <<
" * Extracting range[" << startIndex <<
", " 1726 << *std::next(segments.begin(), index) <<
"]\n");
1732 }
else if (values.size() >= index) {
1733 LLVM_DEBUG(llvm::dbgs()
1734 <<
" * Treating values as trailing variadic range\n");
1735 values = values.drop_front(index);
1744 valueRangeMemory[rangeIndex] = values;
1745 return &valueRangeMemory[rangeIndex];
1749 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1752 void ByteCodeExecutor::executeGetOperands() {
1753 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperands:\n");
1754 unsigned index = read<uint32_t>();
1758 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1759 op->getOperands(), op, index, rangeIndex,
"operand_segment_sizes",
1762 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid operand range\n");
1763 memory[read()] = result;
1766 void ByteCodeExecutor::executeGetResult(
unsigned index) {
1768 unsigned memIndex = read();
1772 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n" 1773 <<
" * Index: " << index <<
"\n" 1774 <<
" * Result: " << result <<
"\n");
1778 void ByteCodeExecutor::executeGetResults() {
1779 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResults:\n");
1780 unsigned index = read<uint32_t>();
1784 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1785 op->getResults(), op, index, rangeIndex,
"result_segment_sizes",
1788 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid result range\n");
1789 memory[read()] = result;
1792 void ByteCodeExecutor::executeGetUsers() {
1793 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetUsers:\n");
1794 unsigned memIndex = read();
1795 unsigned rangeIndex = read();
1797 memory[memIndex] = ⦥
1802 Value value = read<Value>();
1805 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1816 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
1817 llvm::interleaveComma(*values, llvm::dbgs());
1818 llvm::dbgs() <<
"\n";
1823 for (
Value value : *values)
1829 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << range.size() <<
" operations\n");
1832 void ByteCodeExecutor::executeGetValueType() {
1833 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueType:\n");
1834 unsigned memIndex = read();
1835 Value value = read<Value>();
1838 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n" 1839 <<
" * Result: " << type <<
"\n");
1840 memory[memIndex] = type.getAsOpaquePointer();
1843 void ByteCodeExecutor::executeGetValueRangeTypes() {
1844 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueRangeTypes:\n");
1845 unsigned memIndex = read();
1846 unsigned rangeIndex = read();
1849 LLVM_DEBUG(llvm::dbgs() <<
" * Values: <NULL>\n\n");
1850 memory[memIndex] =
nullptr;
1855 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
1856 llvm::interleaveComma(*values, llvm::dbgs());
1857 llvm::dbgs() <<
"\n * Result: ";
1858 llvm::interleaveComma(values->
getType(), llvm::dbgs());
1859 llvm::dbgs() <<
"\n";
1861 typeRangeMemory[rangeIndex] = values->
getType();
1862 memory[memIndex] = &typeRangeMemory[rangeIndex];
1865 void ByteCodeExecutor::executeIsNotNull() {
1866 LLVM_DEBUG(llvm::dbgs() <<
"Executing IsNotNull:\n");
1867 const void *value = read<const void *>();
1869 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1870 selectJump(value !=
nullptr);
1873 void ByteCodeExecutor::executeRecordMatch(
1876 LLVM_DEBUG(llvm::dbgs() <<
"Executing RecordMatch:\n");
1877 unsigned patternIndex = read();
1884 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: Impossible To Match\n");
1893 unsigned numMatchLocs = read();
1895 matchLocs.reserve(numMatchLocs);
1896 for (
unsigned i = 0; i != numMatchLocs; ++i)
1897 matchLocs.push_back(read<Operation *>()->getLoc());
1900 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: " << benefit.
getBenefit() <<
"\n" 1901 <<
" * Location: " << matchLoc <<
"\n");
1902 matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1908 unsigned numInputs = read();
1909 match.
values.reserve(numInputs);
1912 for (
unsigned i = 0; i < numInputs; ++i) {
1913 switch (read<PDLValue::Kind>()) {
1923 match.
values.push_back(read<const void *>());
1931 LLVM_DEBUG(llvm::dbgs() <<
"Executing ReplaceOp:\n");
1934 readValueList(args);
1937 llvm::dbgs() <<
" * Operation: " << *op <<
"\n" 1939 llvm::interleaveComma(args, llvm::dbgs());
1940 llvm::dbgs() <<
"\n";
1945 void ByteCodeExecutor::executeSwitchAttribute() {
1946 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchAttribute:\n");
1948 ArrayAttr cases = read<ArrayAttr>();
1949 handleSwitch(value, cases);
1952 void ByteCodeExecutor::executeSwitchOperandCount() {
1953 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperandCount:\n");
1955 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1957 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
1961 void ByteCodeExecutor::executeSwitchOperationName() {
1962 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperationName:\n");
1964 size_t caseCount = read();
1971 llvm::dbgs() <<
" * Value: " << value <<
"\n" 1973 llvm::interleaveComma(
1974 llvm::map_range(llvm::seq<size_t>(0, caseCount),
1975 [&](
size_t) {
return read<OperationName>(); }),
1977 llvm::dbgs() <<
"\n";
1978 curCodeIt = prevCodeIt;
1982 for (
size_t i = 0; i != caseCount; ++i) {
1983 if (read<OperationName>() ==
value) {
1984 curCodeIt += (caseCount - i - 1);
1985 return selectJump(i + 1);
1988 selectJump(
size_t(0));
1991 void ByteCodeExecutor::executeSwitchResultCount() {
1992 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchResultCount:\n");
1994 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1996 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
2000 void ByteCodeExecutor::executeSwitchType() {
2001 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchType:\n");
2002 Type value = read<Type>();
2003 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2004 handleSwitch(value, cases);
2007 void ByteCodeExecutor::executeSwitchTypes() {
2008 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchTypes:\n");
2010 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2012 LLVM_DEBUG(llvm::dbgs() <<
"Types: <NULL>\n");
2013 return selectJump(
size_t(0));
2015 handleSwitch(*value, cases, [](ArrayAttr caseValue,
const TypeRange &value) {
2016 return value == caseValue.getAsValueRange<TypeAttr>();
2020 void ByteCodeExecutor::execute(
2026 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() <<
"\n");
2030 case ApplyConstraint:
2031 executeApplyConstraint(rewriter);
2034 executeApplyRewrite(rewriter);
2039 case AreRangesEqual:
2040 executeAreRangesEqual();
2045 case CheckOperandCount:
2046 executeCheckOperandCount();
2048 case CheckOperationName:
2049 executeCheckOperationName();
2051 case CheckResultCount:
2052 executeCheckResultCount();
2055 executeCheckTypes();
2060 case CreateOperation:
2061 executeCreateOperation(rewriter, *mainRewriteLoc);
2064 executeCreateTypes();
2067 executeEraseOp(rewriter);
2070 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2073 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2076 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2080 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2086 executeGetAttribute();
2088 case GetAttributeType:
2089 executeGetAttributeType();
2092 executeGetDefiningOp();
2098 unsigned index = opCode - GetOperand0;
2099 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperand" << index <<
":\n");
2100 executeGetOperand(index);
2104 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperandN:\n");
2105 executeGetOperand(read<uint32_t>());
2108 executeGetOperands();
2114 unsigned index = opCode - GetResult0;
2115 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResult" << index <<
":\n");
2116 executeGetResult(index);
2120 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResultN:\n");
2121 executeGetResult(read<uint32_t>());
2124 executeGetResults();
2130 executeGetValueType();
2132 case GetValueRangeTypes:
2133 executeGetValueRangeTypes();
2140 "expected matches to be provided when executing the matcher");
2141 executeRecordMatch(rewriter, *matches);
2144 executeReplaceOp(rewriter);
2146 case SwitchAttribute:
2147 executeSwitchAttribute();
2149 case SwitchOperandCount:
2150 executeSwitchOperandCount();
2152 case SwitchOperationName:
2153 executeSwitchOperationName();
2155 case SwitchResultCount:
2156 executeSwitchResultCount();
2159 executeSwitchType();
2162 executeSwitchTypes();
2165 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2175 state.memory[0] = op;
2178 ByteCodeExecutor executor(
2179 matcherByteCode.data(), state.memory, state.opRangeMemory,
2180 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2181 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2182 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2183 constraintFunctions, rewriteFunctions);
2184 executor.execute(rewriter, &matches);
2187 std::stable_sort(matches.begin(), matches.end(),
2189 return lhs.
benefit > rhs.benefit;
2200 ByteCodeExecutor executor(
2202 state.opRangeMemory, state.typeRangeMemory,
2203 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2204 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2205 rewriterByteCode, state.currentPatternBenefits, patterns,
2206 constraintFunctions, rewriteFunctions);
2207 executor.execute(rewriter,
nullptr, match.
location);
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Operation is a basic unit of execution within MLIR.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
This is a value defined by a result of an operation.
Block represents an ordered list of Operations.
This class represents liveness information on block level.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Each successful match returns a MatchResult, which contains information necessary to execute the rewr...
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
unsigned getNumOperands()
llvm::OwningArrayRef< Operation * > OwningOpRange
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
This class implements the result iterators for the Operation class.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
SmallVector< const void * > values
Memory values defined in the matcher that are passed to the rewriter.
PDLByteCode(ModuleOp module, llvm::StringMap< PDLConstraintFunction > constraintFns, llvm::StringMap< PDLRewriteFunction > rewriteFns)
Create a ByteCode instance from the given module containing operations in the PDL interpreter dialect...
user_range getUsers() const
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
static constexpr const bool value
SmallVector< Value, 4 > operands
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Auxiliary range data structure to unpack the offset, size and stride operands into a list of triples...
std::function< void(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
All of the data pertaining to a specific pattern within the bytecode.
uint16_t ByteCodeField
Use generic bytecode types.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, ByteCodeAddr rewriterAddr)
void rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Run the rewriter of the given pattern that was previously matched in match.
SmallVector< ValueRange, 0 > valueRangeValues
Storage type of byte-code interpreter values.
Attributes are known-constant values of operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
user_iterator user_begin() const
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< TypeRange, 0 > typeRangeValues
Memory used for the range input values.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
This class provides an abstraction over the various different ranges of value types.
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, None otherwise.
Location getLoc()
The source location the operation was defined or derived from.
This represents an operation in an abstracted form, suitable for use with the builder APIs...
This class acts as a special tag that makes the desire to match "any" operation type explicit...
BlockArgListType getArguments()
Represents an analysis for computing liveness information from a given top-level operation.
This class represents an argument of a Block.
Kind
The underlying kind of a PDL value.
This class implements the successor iterators for Block.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
user_iterator user_end() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
MLIRContext * getContext() const
Get the context held by this operation state.
bool isImpossibleToMatch() const
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Type getType() const
Return the type of this value.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a match/rewrite has been completed.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation...
type_range getType() const
Location location
The location of operations to be replaced.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
MLIRContext is the top-level object for a collection of MLIR operations.
The class represents a list of PDL results, returned by a native rewrite method.
const PDLByteCodePattern * pattern
The originating pattern that was matched.
This class implements the operand iterators for the Operation class.
unsigned getNumResults()
Return the number of results held by this operation.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
This class contains the mutable state of a bytecode instance.
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Run the pattern matcher on the given root operation, collecting the matched patterns in matches...
PatternBenefit benefit
The current benefit of the pattern that was matched.
OperationName getName()
The name of an operation is the key identifier for it.
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
result_range getResults()
This class provides an abstraction over the different types of ranges over Values.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
ByteCodeAddr getRewriterAddr() const
Return the bytecode address of the rewriter for this pattern.
static const mlir::GenInfo * generator
static void processValue(Value value, LiveMap &liveMap)
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
SmallVector< Type, 4 > types
Types of the results of this operation.