19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/Format.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/InterleavedRange.h"
30 #define DEBUG_TYPE "pdl-bytecode"
40 PDLPatternConfigSet *configSet,
41 ByteCodeAddr rewriterAddr) {
47 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
49 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
52 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
56 benefit, ctx, generatedOps);
68 currentPatternBenefits[patternIndex] = benefit;
75 allocatedTypeRangeMemory.clear();
76 allocatedValueRangeMemory.clear();
84 enum OpCode : ByteCodeField {
106 CreateConstantTypeRange,
110 CreateDynamicTypeRange,
112 CreateDynamicValueRange,
187 struct ByteCodeLiveRange;
188 struct ByteCodeWriter;
/// Detection helper: the aliased type is the result type of calling
/// `getAsOpaquePointer()` on a `T`, so the alias is only well-formed for types
/// that expose that member. It is consumed through `llvm::is_detected` by the
/// templates in this file. The `Args` pack is not used by the alias body.
191 template <
typename T,
typename... Args>
192 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
197 Generator(
MLIRContext *ctx, std::vector<const void *> &uniquedData,
201 ByteCodeField &maxValueMemoryIndex,
202 ByteCodeField &maxOpRangeMemoryIndex,
203 ByteCodeField &maxTypeRangeMemoryIndex,
204 ByteCodeField &maxValueRangeMemoryIndex,
205 ByteCodeField &maxLoopLevel,
206 llvm::StringMap<PDLConstraintFunction> &constraintFns,
207 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
209 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
211 maxValueMemoryIndex(maxValueMemoryIndex),
212 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
213 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
214 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
215 maxLoopLevel(maxLoopLevel), configMap(configMap) {
217 constraintToMemIndex.try_emplace(it.value().first(), it.index());
219 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
223 void generate(ModuleOp module);
/// Returns a mutable reference to the memory index stored for `value` in
/// `valueToMemIndex`. Asserts that an index has already been assigned.
226 ByteCodeField &getMemIndex(
Value value) {
227 assert(valueToMemIndex.count(value) &&
228 "expected memory index to be assigned");
229 return valueToMemIndex[value];
/// Returns a mutable reference to the range-storage index stored for `value`
/// in `valueToRangeIndex`. Asserts that an index has already been assigned.
233 ByteCodeField &getRangeStorageIndex(
Value value) {
234 assert(valueToRangeIndex.count(value) &&
235 "expected range index to be assigned");
236 return valueToRangeIndex[value];
241 template <
typename T>
242 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
244 const void *opaqueVal = val.getAsOpaquePointer();
247 auto it = uniquedDataToMemIndex.try_emplace(
248 opaqueVal, maxValueMemoryIndex + uniquedData.size());
250 uniquedData.push_back(opaqueVal);
251 return it.first->second;
/// Assigns memory indices (and range-storage indices for range-typed values)
/// to the values used by the rewriter functions in `rewriterModule` and by
/// `matcherFunc`, raising the `max*MemoryIndex` counters when needed.
257 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
258 ModuleOp rewriterModule);
/// Bytecode generation entry points. The `Region *` overload walks the blocks
/// of a region, the `Operation *` overload dispatches on the concrete
/// `pdl_interp` operation type, and each remaining overload emits the bytecode
/// for one specific `pdl_interp` operation into `writer`.
261 void generate(
Region *region, ByteCodeWriter &writer);
262 void generate(
Operation *op, ByteCodeWriter &writer);
263 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
298 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
299 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
300 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
310 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
314 llvm::StringMap<ByteCodeField> constraintToMemIndex;
318 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
325 ByteCodeField curLoopLevel = 0;
334 std::vector<const void *> &uniquedData;
338 ByteCodeField &maxValueMemoryIndex;
339 ByteCodeField &maxOpRangeMemoryIndex;
340 ByteCodeField &maxTypeRangeMemoryIndex;
341 ByteCodeField &maxValueRangeMemoryIndex;
342 ByteCodeField &maxLoopLevel;
349 struct ByteCodeWriter {
/// Appends a single raw field to the end of the bytecode buffer.
354 void append(ByteCodeField field) { bytecode.push_back(field); }
/// Appends an opcode to the end of the bytecode buffer as one entry.
355 void append(OpCode opCode) { bytecode.push_back(opCode); }
/// Appends a bytecode address. An address is exactly two fields wide (checked
/// by the static_assert), so it is split into two fields before being stored.
358 void append(ByteCodeAddr field) {
359 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
360 "unexpected ByteCode address size");
// Copy the raw bytes of the address into two field-sized parts.
362 ByteCodeField fieldParts[2];
363 std::memcpy(fieldParts, &field,
sizeof(ByteCodeAddr));
364 bytecode.append({fieldParts[0], fieldParts[1]});
/// Appends a reference to a successor block. The current bytecode offset is
/// recorded under `successor` in `unresolvedSuccessorRefs` and a zero address
/// is written as a placeholder; Generator::generate(ModuleOp) later overwrites
/// the recorded offsets with the address stored for the block.
369 void append(
Block *successor) {
372 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
373 append(ByteCodeAddr(0));
379 for (
Block *successor : successors)
385 bytecode.push_back(values.size());
386 for (
Value value : values)
387 appendPDLValue(value);
/// Begins the encoding of `value` as a PDLValue by emitting the kind tag
/// derived from the type of the value.
391 void appendPDLValue(
Value value) {
392 appendPDLValueKind(value);
/// Appends the PDLValue kind tag for `value`, forwarding to the `Type`
/// overload with the type of the value.
397 void appendPDLValueKind(
Value value) { appendPDLValueKind(value.
getType()); }
/// Appends the PDLValue kind tag that corresponds to a PDL type. The mapping
/// is: AttributeType to Attribute, OperationType to Operation, TypeType to
/// Type, ValueType to Value. A RangeType maps to TypeRange when its element
/// type is a TypeType and to ValueRange otherwise.
400 void appendPDLValueKind(
Type type) {
403 .Case<pdl::AttributeType>(
404 [](
Type) {
return PDLValue::Kind::Attribute; })
405 .Case<pdl::OperationType>(
406 [](
Type) {
return PDLValue::Kind::Operation; })
407 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
408 if (isa<pdl::TypeType>(rangeTy.getElementType()))
409 return PDLValue::Kind::TypeRange;
410 return PDLValue::Kind::ValueRange;
412 .Case<pdl::TypeType>([](
Type) {
return PDLValue::Kind::Type; })
413 .Case<pdl::ValueType>([](
Type) {
return PDLValue::Kind::Value; });
// The kind is stored as a single bytecode field.
414 bytecode.push_back(
static_cast<ByteCodeField
>(
kind));
419 template <
typename T>
420 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
421 std::is_pointer<T>::value>
423 bytecode.push_back(
generator.getMemIndex(value));
427 template <
typename T,
typename IteratorT = llvm::detail::IterOfRange<T>>
428 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
430 bytecode.push_back(llvm::size(range));
431 for (
auto it : range)
/// Variadic form of append that accepts two or more fields. Everything after
/// the first field is forwarded to a further append call, so the remaining
/// fields are handled by whichever overload matches them.
436 template <
typename FieldTy,
typename Field2Ty,
typename... FieldTys>
437 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
439 append(field2, fields...);
/// Writes the opaque pointer of `value` directly into the bytecode stream.
/// Only enabled for types that provide `getAsOpaquePointer()`. The pointer
/// occupies `sizeof(const void *) / sizeof(ByteCodeField)` consecutive fields.
443 template <
typename T>
444 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
445 appendInline(T value) {
446 constexpr
size_t numParts =
sizeof(
const void *) /
sizeof(ByteCodeField);
447 const void *pointer = value.getAsOpaquePointer();
// Split the pointer bytes into field-sized parts and append them in order.
448 ByteCodeField fieldParts[numParts];
449 std::memcpy(fieldParts, &pointer,
sizeof(
const void *));
450 bytecode.append(fieldParts, fieldParts + numParts);
465 struct ByteCodeLiveRange {
466 using Set = llvm::IntervalMap<uint64_t, char, 16>;
467 using Allocator = Set::Allocator;
/// Creates a live range whose interval map is newly allocated and uses
/// `alloc` as its allocator.
469 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
/// Merges the intervals of `rhs` into this live range. Every interval of
/// `rhs` is inserted into `liveness` with the mapped value 0.
472 void unionWith(
const ByteCodeLiveRange &rhs) {
473 for (
auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
475 liveness->insert(it.start(), it.stop(), 0);
479 bool overlaps(
const ByteCodeLiveRange &rhs)
const {
480 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
490 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
493 std::optional<unsigned> opRangeIndex;
496 std::optional<unsigned> typeRangeIndex;
499 std::optional<unsigned> valueRangeIndex;
/// Generates the bytecode for a PDL interpreter module. The module must hold
/// the matcher function and the rewriter module under the names reported by
/// PDLInterpDialect; this is asserted. Memory indices are allocated first,
/// then rewriter bytecode is emitted, then matcher bytecode, and finally the
/// placeholder successor addresses in the matcher bytecode are patched.
503 void Generator::generate(ModuleOp module) {
504 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
505 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
506 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
507 pdl_interp::PDLInterpDialect::getRewriterModuleName());
508 assert(matcherFunc && rewriterModule &&
"invalid PDL Interpreter module");
512 allocateMemoryIndices(matcherFunc, rewriterModule);
// Emit each rewriter function and remember the bytecode offset at which it
// starts, keyed by the function name.
515 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *
this);
516 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
517 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
518 for (
Operation &op : rewriterFunc.getOps())
519 generate(&op, rewriterByteCodeWriter);
// Rewriter code is expected to contain no successor references.
521 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
522 "unexpected branches in rewriter function");
525 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *
this);
526 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
// Overwrite every recorded placeholder with the address of its block.
529 for (
auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
530 ByteCodeAddr addr = blockToAddr[it.first];
531 for (
unsigned offsetToFix : it.second)
532 std::memcpy(&matcherByteCode[offsetToFix], &addr,
sizeof(ByteCodeAddr));
/// Assigns memory slots to the values of the rewriter functions and of the
/// matcher function. Rewriter values receive a fresh index each, counted per
/// rewriter function. Matcher values are described by live ranges, and values
/// whose live ranges do not overlap share a slot. Range-typed values also get
/// an index into the matching range storage. The `max*MemoryIndex` counters
/// are raised to cover everything that was allocated.
536 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
537 ModuleOp rewriterModule) {
// Rewriter functions: indices restart at zero for every function.
540 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
541 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
// Gives `val` the next memory index and, when it is a range of types or of
// values, the next index in the corresponding range storage.
542 auto processRewriterValue = [&](
Value val) {
543 valueToMemIndex.try_emplace(val, index++);
544 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
545 Type elementTy = rangeType.getElementType();
546 if (isa<pdl::TypeType>(elementTy))
547 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
548 else if (isa<pdl::ValueType>(elementTy))
549 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
554 processRewriterValue(arg);
557 processRewriterValue(result);
// Keep the running maxima large enough for this rewriter function.
559 if (index > maxValueMemoryIndex)
560 maxValueMemoryIndex = index;
561 if (typeRangeIndex > maxTypeRangeMemoryIndex)
562 maxTypeRangeMemoryIndex = typeRangeIndex;
563 if (valueRangeIndex > maxValueRangeMemoryIndex)
564 maxValueRangeMemoryIndex = valueRangeIndex;
// Operations receive a first and a last position number; these numbers are
// the coordinates used by the live range intervals below.
581 opToFirstIndex.try_emplace(op, index++);
583 for (
Block &block : region.getBlocks())
586 opToLastIndex.try_emplace(op, index++);
591 ByteCodeLiveRange::Allocator allocator;
// The root operation argument always lives in slot 0.
596 valueToMemIndex[rootOpArg] = 0;
// Build a live range for every matcher value from the liveness analysis.
599 Liveness matcherLiveness(matcherFunc);
600 matcherFunc->walk([&](
Block *block) {
602 assert(info &&
"expected liveness info for block");
606 if (value == rootOpArg)
610 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
611 defRangeIt->second.liveness->insert(
612 opToFirstIndex[firstUseOrDef],
// Mark which kind of range storage the value needs; the 0 stored here only
// flags the kind, the real index is chosen during slot allocation below.
617 if (
auto rangeTy = dyn_cast<pdl::RangeType>(value.
getType())) {
618 Type eleType = rangeTy.getElementType();
619 if (isa<pdl::OperationType>(eleType))
620 defRangeIt->second.opRangeIndex = 0;
621 else if (isa<pdl::TypeType>(eleType))
622 defRangeIt->second.typeRangeIndex = 0;
623 else if (isa<pdl::ValueType>(eleType))
624 defRangeIt->second.valueRangeIndex = 0;
629 for (
Value liveIn : info->
in()) {
634 if (liveIn.getParentRegion() == block->
getParent())
// Slot allocation. `allocatedIndices[i]` is the combined live range of all
// values placed in slot i + 1; slot 0 belongs to the root operation.
651 std::vector<ByteCodeLiveRange> allocatedIndices;
655 ByteCodeField numIndices = 1;
658 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
660 for (
auto &defIt : valueDefRanges) {
661 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
662 ByteCodeLiveRange &defRange = defIt.second;
// Look for an existing slot whose live range does not overlap this value.
665 for (
const auto &existingIndexIt :
llvm::enumerate(allocatedIndices)) {
666 ByteCodeLiveRange &existingRange = existingIndexIt.value();
667 if (!defRange.overlaps(existingRange)) {
668 existingRange.unionWith(defRange);
669 memIndex = existingIndexIt.index() + 1;
// Reuse the range storage index of the slot, creating one on first need.
671 if (defRange.opRangeIndex) {
672 if (!existingRange.opRangeIndex)
673 existingRange.opRangeIndex = numOpRanges++;
674 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
675 }
else if (defRange.typeRangeIndex) {
676 if (!existingRange.typeRangeIndex)
677 existingRange.typeRangeIndex = numTypeRanges++;
678 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
679 }
else if (defRange.valueRangeIndex) {
680 if (!existingRange.valueRangeIndex)
681 existingRange.valueRangeIndex = numValueRanges++;
682 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
// New slot: seeded with the live range of this value.
690 allocatedIndices.emplace_back(allocator);
691 ByteCodeLiveRange &newRange = allocatedIndices.back();
692 newRange.unionWith(defRange);
695 if (defRange.opRangeIndex) {
696 newRange.opRangeIndex = numOpRanges;
697 valueToRangeIndex[defIt.first] = numOpRanges++;
698 }
else if (defRange.typeRangeIndex) {
699 newRange.typeRangeIndex = numTypeRanges;
700 valueToRangeIndex[defIt.first] = numTypeRanges++;
701 }
else if (defRange.valueRangeIndex) {
702 newRange.valueRangeIndex = numValueRanges;
703 valueToRangeIndex[defIt.first] = numValueRanges++;
// The size after the push equals the 1-based slot number of the new entry.
706 memIndex = allocatedIndices.size();
712 LDBG() <<
"Allocated " << allocatedIndices.size() <<
" indices "
713 <<
"(down from initial " << valueDefRanges.size() <<
").";
715 "Ran out of memory for allocated indices");
// Raise the maxima so they cover the matcher allocation as well.
718 if (numIndices > maxValueMemoryIndex)
719 maxValueMemoryIndex = numIndices;
720 if (numOpRanges > maxOpRangeMemoryIndex)
721 maxOpRangeMemoryIndex = numOpRanges;
722 if (numTypeRanges > maxTypeRangeMemoryIndex)
723 maxTypeRangeMemoryIndex = numTypeRanges;
724 if (numValueRanges > maxValueRangeMemoryIndex)
725 maxValueRangeMemoryIndex = numValueRanges;
/// Generates bytecode for the blocks of `region`, visited in reverse
/// post-order. The offset at which each block starts in `matcherByteCode` is
/// recorded in `blockToAddr` before the operations of the block are emitted.
728 void Generator::generate(
Region *region, ByteCodeWriter &writer) {
729 llvm::ReversePostOrderTraversal<Region *> rpot(region);
730 for (
Block *block : rpot) {
732 blockToAddr.try_emplace(block, matcherByteCode.size());
734 generate(&op, writer);
/// Generates bytecode for a single operation by dispatching to the overload
/// for its concrete `pdl_interp` operation type. Reaching the
/// llvm_unreachable means the operation is not one of the listed types.
738 void Generator::generate(
Operation *op, ByteCodeWriter &writer) {
739 LDBG() <<
"Generating bytecode for operation: " << op->
getName();
// The location is written inline ahead of the opcode for every operation
// except CreateAttributeOp and CreateTypeOp.
743 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
744 writer.appendInline(op->
getLoc());
747 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
748 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
749 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
750 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
751 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
752 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
753 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
754 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
755 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
756 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
757 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
758 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
759 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
760 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
761 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
762 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
763 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
764 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
765 pdl_interp::SwitchResultCountOp>(
766 [&](
auto interpOp) { this->generate(interpOp, writer); })
768 llvm_unreachable(
"unknown `pdl_interp` operation");
/// Emits ApplyConstraint. Layout in emission order: opcode, index of the
/// constraint function looked up by name, argument list, negation flag,
/// result count, then for each result its kind tag, its range-storage index
/// when the result is a range, and its memory index. The successors of the
/// operation are emitted last.
772 void Generator::generate(pdl_interp::ApplyConstraintOp op,
773 ByteCodeWriter &writer) {
777 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
778 writer.appendPDLValueList(op.getArgs());
779 writer.append(ByteCodeField(op.getIsNegated()));
781 writer.append(ByteCodeField(results.size()));
782 for (
Value result : results) {
786 writer.appendPDLValueKind(result);
// Range results additionally need the index of their range storage.
789 if (isa<pdl::RangeType>(result.getType()))
790 writer.append(getRangeStorageIndex(result));
791 writer.append(result);
793 writer.append(op.getSuccessors());
/// Emits ApplyRewrite. Asserts that the external rewrite function has an
/// index. Layout in emission order: opcode, index of the rewrite function
/// looked up by name, argument list, result count, then for each result its
/// kind tag, its range-storage index when the result is a range, and its
/// memory index.
795 void Generator::generate(pdl_interp::ApplyRewriteOp op,
796 ByteCodeWriter &writer) {
797 assert(externalRewriterToMemIndex.count(op.getName()) &&
798 "expected index for rewrite function");
799 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
800 writer.appendPDLValueList(op.getArgs());
803 writer.append(ByteCodeField(results.size()));
804 for (
Value result : results) {
807 writer.appendPDLValueKind(result);
// Range results additionally need the index of their range storage.
810 if (isa<pdl::RangeType>(result.getType()))
811 writer.append(getRangeStorageIndex(result));
812 writer.append(result);
/// Emits an equality check. When the left operand is a range the opcode is
/// AreRangesEqual followed by the kind tag of the operand; otherwise the
/// opcode is AreEqual. Both forms then carry the two operands and the
/// successors of the operation.
815 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
816 Value lhs = op.getLhs();
817 if (isa<pdl::RangeType>(lhs.
getType())) {
818 writer.append(OpCode::AreRangesEqual);
819 writer.appendPDLValueKind(lhs);
820 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
824 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
826 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
829 void Generator::generate(pdl_interp::CheckAttributeOp op,
830 ByteCodeWriter &writer) {
831 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
834 void Generator::generate(pdl_interp::CheckOperandCountOp op,
835 ByteCodeWriter &writer) {
836 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
837 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
840 void Generator::generate(pdl_interp::CheckOperationNameOp op,
841 ByteCodeWriter &writer) {
842 writer.append(OpCode::CheckOperationName, op.getInputOp(),
845 void Generator::generate(pdl_interp::CheckResultCountOp op,
846 ByteCodeWriter &writer) {
847 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
848 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
851 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
852 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
855 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
856 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
/// Emits Continue followed by the zero-based level of the enclosing loop,
/// which is `curLoopLevel - 1`. Asserts that the operation is inside a loop.
859 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
860 assert(curLoopLevel > 0 &&
"encountered pdl_interp.continue at top level");
861 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
/// Handles CreateAttributeOp by making the result share the memory index of
/// the constant attribute value; the visible code writes nothing to `writer`.
863 void Generator::generate(pdl_interp::CreateAttributeOp op,
864 ByteCodeWriter &writer) {
866 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
868 void Generator::generate(pdl_interp::CreateOperationOp op,
869 ByteCodeWriter &writer) {
870 writer.append(OpCode::CreateOperation, op.getResultOp(),
872 writer.appendPDLValueList(op.getInputOperands());
876 writer.append(
static_cast<ByteCodeField
>(attributes.size()));
877 for (
auto it : llvm::zip(op.getInputAttributeNames(), attributes))
878 writer.append(std::get<0>(it), std::get<1>(it));
882 if (op.getInferredResultTypes())
885 writer.appendPDLValueList(op.getInputResultTypes());
887 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
891 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
892 .Case([&](pdl::ValueType) {
893 writer.append(OpCode::CreateDynamicValueRange);
896 writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
897 writer.appendPDLValueList(op->getOperands());
/// Handles CreateTypeOp by making the result share the memory index of the
/// constant type value; the visible code writes nothing to `writer`.
899 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
901 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
/// Emits CreateConstantTypeRange followed by the result, the range-storage
/// index of the result, and the constant value of the operation.
903 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
904 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
905 getRangeStorageIndex(op.getResult()), op.getValue());
/// Emits EraseOp followed by the operation to erase.
907 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
908 writer.append(OpCode::EraseOp, op.getInputOp());
910 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
913 .Case([](pdl::OperationType) {
return OpCode::ExtractOp; })
914 .Case([](pdl::ValueType) {
return OpCode::ExtractValue; })
915 .Case([](pdl::TypeType) {
return OpCode::ExtractType; })
916 .Default([](
Type) -> OpCode {
917 llvm_unreachable(
"unsupported element type");
919 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
/// Emits the Finalize opcode, which takes no operands.
921 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
922 writer.append(OpCode::Finalize);
/// Emits ForEach. Layout in emission order: opcode, range-storage index of
/// the iterated values, `arg`, the kind tag of `arg`, the current loop level
/// and the successor. `maxLoopLevel` is raised when the current level exceeds
/// it, and the loop body region is generated afterwards.
924 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
926 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
927 writer.appendPDLValueKind(arg.
getType());
928 writer.append(curLoopLevel, op.getSuccessor());
930 if (curLoopLevel > maxLoopLevel)
931 maxLoopLevel = curLoopLevel;
932 generate(&op.getRegion(), writer);
935 void Generator::generate(pdl_interp::GetAttributeOp op,
936 ByteCodeWriter &writer) {
937 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
/// Emits GetAttributeType followed by the result and the input attribute.
940 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
941 ByteCodeWriter &writer) {
942 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
/// Emits GetDefiningOp followed by the result operation and then the input
/// value encoded as a PDLValue.
944 void Generator::generate(pdl_interp::GetDefiningOpOp op,
945 ByteCodeWriter &writer) {
946 writer.append(OpCode::GetDefiningOp, op.getInputOp());
947 writer.appendPDLValue(op.getValue());
949 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
950 uint32_t index = op.getIndex();
952 writer.append(
static_cast<OpCode
>(OpCode::GetOperand0 + index));
954 writer.append(OpCode::GetOperandN, index);
955 writer.append(op.getInputOp(), op.getValue());
957 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
958 Value result = op.getValue();
959 std::optional<uint32_t> index = op.getIndex();
960 writer.append(OpCode::GetOperands,
963 if (isa<pdl::RangeType>(result.
getType()))
964 writer.append(getRangeStorageIndex(result));
967 writer.append(result);
969 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
970 uint32_t index = op.getIndex();
972 writer.append(
static_cast<OpCode
>(OpCode::GetResult0 + index));
974 writer.append(OpCode::GetResultN, index);
975 writer.append(op.getInputOp(), op.getValue());
977 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
978 Value result = op.getValue();
979 std::optional<uint32_t> index = op.getIndex();
980 writer.append(OpCode::GetResults,
983 if (isa<pdl::RangeType>(result.
getType()))
984 writer.append(getRangeStorageIndex(result));
987 writer.append(result);
/// Emits GetUsers followed by the result range of operations, the
/// range-storage index of that result, and the input value encoded as a
/// PDLValue.
989 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
990 Value operations = op.getOperations();
991 ByteCodeField rangeIndex = getRangeStorageIndex(operations);
992 writer.append(OpCode::GetUsers, operations, rangeIndex);
993 writer.appendPDLValue(op.getValue());
/// Emits the type query for a value. A range result uses GetValueRangeTypes
/// with the result, its range-storage index and the input; the other form is
/// GetValueType with the result and the input.
995 void Generator::generate(pdl_interp::GetValueTypeOp op,
996 ByteCodeWriter &writer) {
997 if (isa<pdl::RangeType>(op.getType())) {
998 Value result = op.getResult();
999 writer.append(OpCode::GetValueRangeTypes, result,
1000 getRangeStorageIndex(result), op.getValue());
1002 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
/// Emits IsNotNull followed by the tested value and the successors.
1005 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1006 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1008 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1009 ByteCodeField patternIndex =
patterns.size();
1010 patterns.emplace_back(PDLByteCodePattern::create(
1011 op, configMap.lookup(op),
1012 rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1013 writer.append(OpCode::RecordMatch, patternIndex,
1015 writer.appendPDLValueList(op.getInputs());
/// Emits ReplaceOp followed by the operation being replaced and the list of
/// replacement values.
1017 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1018 writer.append(OpCode::ReplaceOp, op.getInputOp());
1019 writer.appendPDLValueList(op.getReplValues());
/// Emits SwitchAttribute followed by the switched attribute, the case values
/// attribute and the successors.
1021 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1022 ByteCodeWriter &writer) {
1023 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1024 op.getCaseValuesAttr(), op.getSuccessors());
/// Emits SwitchOperandCount followed by the inspected operation, the case
/// values attribute and the successors.
1026 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1027 ByteCodeWriter &writer) {
1028 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1029 op.getCaseValuesAttr(), op.getSuccessors());
/// Emits SwitchOperationName followed by the inspected operation, the case
/// list and the successors. Each case string attribute is converted to an
/// OperationName in the current context before it is emitted.
1031 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1032 ByteCodeWriter &writer) {
1033 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](
Attribute attr) {
1034 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1036 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1037 op.getSuccessors());
/// Emits SwitchResultCount followed by the inspected operation, the case
/// values attribute and the successors.
1039 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1040 ByteCodeWriter &writer) {
1041 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1042 op.getCaseValuesAttr(), op.getSuccessors());
/// Emits SwitchType followed by the switched type value, the case values
/// attribute and the successors.
1044 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1045 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1046 op.getSuccessors());
/// Emits SwitchTypes followed by the switched type range value, the case
/// values attribute and the successors.
1048 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1049 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1050 op.getSuccessors());
/// Builds the bytecode for `module`. Ownership of `configs` is taken. A
/// Generator is set up over the storage members of this object (unique data,
/// matcher and rewriter bytecode, patterns and the size counters), and the
/// constraint and rewrite functions are then moved out of the name-keyed maps
/// into the flat function vectors.
1057 PDLByteCode::PDLByteCode(
1058 ModuleOp module,
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1060 llvm::StringMap<PDLConstraintFunction> constraintFns,
1061 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1062 : configs(std::move(configs)) {
1063 Generator
generator(module.getContext(), uniquedData, matcherByteCode,
1064 rewriterByteCode,
patterns, maxValueMemoryIndex,
1065 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1066 maxLoopLevel, constraintFns, rewriteFns, configMap);
// NOTE(review): the functions are stored in map iteration order; presumably
// this has to agree with the indices the Generator assigned per name. Confirm.
1070 for (
auto &it : constraintFns)
1071 constraintFunctions.push_back(std::move(it.second));
1072 for (
auto &it : rewriteFns)
1073 rewriteFunctions.push_back(std::move(it.second));
1079 state.memory.resize(maxValueMemoryIndex,
nullptr);
1080 state.opRangeMemory.resize(maxOpRangeCount);
1081 state.typeRangeMemory.resize(maxTypeRangeCount,
TypeRange());
1082 state.valueRangeMemory.resize(maxValueRangeCount,
ValueRange());
1083 state.loopIndex.resize(maxLoopLevel, 0);
1084 state.currentPatternBenefits.reserve(
patterns.size());
1086 state.currentPatternBenefits.push_back(pattern.getBenefit());
1097 class ByteCodeRewriteResultList :
public PDLResultList {
1099 ByteCodeRewriteResultList(
unsigned maxNumResults)
1100 : PDLResultList(maxNumResults) {}
1107 return allocatedTypeRanges;
1112 return allocatedValueRanges;
1117 class ByteCodeExecutor {
1123 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1125 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1132 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1133 typeRangeMemory(typeRangeMemory),
1134 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1135 valueRangeMemory(valueRangeMemory),
1136 allocatedValueRangeMemory(allocatedValueRangeMemory),
1137 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1139 constraintFunctions(constraintFunctions),
1140 rewriteFunctions(rewriteFunctions) {}
1148 std::optional<Location> mainRewriteLoc = {});
1154 void executeAreEqual();
1155 void executeAreRangesEqual();
1156 void executeBranch();
1157 void executeCheckOperandCount();
1158 void executeCheckOperationName();
1159 void executeCheckResultCount();
1160 void executeCheckTypes();
1161 void executeContinue();
1162 void executeCreateConstantTypeRange();
1165 template <
typename T>
1166 void executeDynamicCreateRange(StringRef type);
1168 template <
typename T,
typename Range, PDLValue::Kind kind>
1169 void executeExtract();
1170 void executeFinalize();
1171 void executeForEach();
1172 void executeGetAttribute();
1173 void executeGetAttributeType();
1174 void executeGetDefiningOp();
1175 void executeGetOperand(
unsigned index);
1176 void executeGetOperands();
1177 void executeGetResult(
unsigned index);
1178 void executeGetResults();
1179 void executeGetUsers();
1180 void executeGetValueType();
1181 void executeGetValueRangeTypes();
1182 void executeIsNotNull();
1186 void executeSwitchAttribute();
1187 void executeSwitchOperandCount();
1188 void executeSwitchOperationName();
1189 void executeSwitchResultCount();
1190 void executeSwitchType();
1191 void executeSwitchTypes();
1192 void processNativeFunResults(ByteCodeRewriteResultList &results,
1193 unsigned numResults,
1194 LogicalResult &rewriteResult);
/// Pushes a code position onto the `resumeCodeIt` stack.
1197 void pushCodeIt(
const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1201 assert(!resumeCodeIt.empty() &&
"attempt to pop code off empty stack");
1202 curCodeIt = resumeCodeIt.pop_back_val();
1206 const ByteCodeField *getPrevCodeIt()
const {
1209 return curCodeIt - 1 -
sizeof(
const void *) /
sizeof(ByteCodeField);
1213 return curCodeIt - 1;
/// Reads a value of type `T` from the bytecode through the matching readImpl
/// overload. `skipN` defaults to 0.
/// NOTE(review): presumably `skipN` fields are skipped before the read; the
/// statement that consumes `skipN` is not among the lines documented here.
1219 template <
typename T = ByteCodeField>
1220 T read(
size_t skipN = 0) {
1222 return readImpl<T>();
/// Non-template convenience overload that reads one raw ByteCodeField by
/// forwarding to `read<ByteCodeField>` with the same `skipN`.
1224 ByteCodeField read(
size_t skipN = 0) {
return read<ByteCodeField>(skipN); }
1227 template <
typename ValueT,
typename T>
1230 for (
unsigned i = 0, e = read(); i != e; ++i)
1231 list.push_back(read<ValueT>());
1237 for (
unsigned i = 0, e = read(); i != e; ++i) {
1238 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1239 list.push_back(read<Type>());
1241 TypeRange *values = read<TypeRange *>();
1242 list.append(values->begin(), values->end());
1247 for (
unsigned i = 0, e = read(); i != e; ++i) {
1248 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1249 list.push_back(read<Value>());
1252 list.append(values->begin(), values->end());
1258 template <
typename T>
1259 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1261 const void *pointer;
1262 std::memcpy(&pointer, curCodeIt,
sizeof(
const void *));
1263 curCodeIt +=
sizeof(
const void *) /
sizeof(ByteCodeField);
1264 return T::getFromOpaquePointer(pointer);
/// Advances the current code position by `skipN` fields.
1267 void skip(
size_t skipN) { curCodeIt += skipN; }
/// Jumps to one of two destinations: index 0 when `isTrue` is true and
/// index 1 when it is false.
1270 void selectJump(
bool isTrue) { selectJump(
size_t(isTrue ? 0 : 1)); }
/// Jumps to the destination with index `destIndex`. The address is read with
/// a skip count of `destIndex * 2` because every address is two fields wide,
/// and the code position is then set to that address within `code`.
1272 void selectJump(
size_t destIndex) {
1273 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
/// Executes a switch over `cases`. The first case for which `cmp(case, value)`
/// holds selects jump destination `position + 1`; when no case matches, jump
/// destination 0 is taken. `cmp` defaults to `std::equal_to<T>`.
1277 template <
typename T,
typename RangeT,
typename Comparator = std::equal_to<T>>
1278 void handleSwitch(
const T &value, RangeT &&cases, Comparator cmp = {}) {
1279 LDBG() <<
"Switch operation:\n * Value: " << value
1280 <<
"\n * Cases: " << llvm::interleaved(cases);
// Destination 0 is the default, so case i maps to destination i + 1.
1284 for (
auto it = cases.begin(), e = cases.end(); it != e; ++it)
1285 if (cmp(*it, value))
1286 return selectJump(
size_t((it - cases.begin()) + 1));
1287 selectJump(
size_t(0));
/// Stores an opaque pointer into the memory slot at `index`.
1291 void storeToMemory(
unsigned index,
const void *value) {
1292 memory[index] = value;
/// Stores `value` into the memory slot at `index` as its opaque pointer.
/// Only enabled for types that provide `getAsOpaquePointer()`.
1296 template <
typename T>
1297 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1298 storeToMemory(
unsigned index, T value) {
1299 memory[index] = value.getAsOpaquePointer();
1304 template <
typename T>
1305 const void *readFromMemory() {
1306 size_t index = *curCodeIt++;
1311 index < memory.size())
1312 return memory[index];
1315 return uniquedMemory[index - memory.size()];
1317 template <
typename T>
1318 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1319 return reinterpret_cast<T
>(
const_cast<void *
>(readFromMemory<T>()));
1321 template <
typename T>
1322 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1325 return T(T::getFromOpaquePointer(readFromMemory<T>()));
1327 template <
typename T>
1328 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1329 switch (read<PDLValue::Kind>()) {
1330 case PDLValue::Kind::Attribute:
1331 return read<Attribute>();
1332 case PDLValue::Kind::Operation:
1333 return read<Operation *>();
1334 case PDLValue::Kind::Type:
1335 return read<Type>();
1336 case PDLValue::Kind::Value:
1337 return read<Value>();
1338 case PDLValue::Kind::TypeRange:
1339 return read<TypeRange *>();
1340 case PDLValue::Kind::ValueRange:
1341 return read<ValueRange *>();
1343 llvm_unreachable(
"unhandled PDLValue::Kind");
1345 template <
typename T>
1346 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1347 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
1348 "unexpected ByteCode address size");
1349 ByteCodeAddr result;
1350 std::memcpy(&result, curCodeIt,
sizeof(ByteCodeAddr));
1354 template <
typename T>
1355 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1356 return *curCodeIt++;
1358 template <
typename T>
1359 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1365 template <
typename RangeT,
typename T = llvm::detail::ValueOfRange<RangeT>>
1366 void assignRangeToMemory(RangeT &&range,
unsigned memIndex,
1367 unsigned rangeIndex) {
1369 auto assignRange = [&](
auto &allocatedRangeMemory,
auto &rangeMemory) {
1371 if (range.empty()) {
1372 rangeMemory[rangeIndex] = {};
1375 llvm::OwningArrayRef<T> storage(llvm::size(range));
1380 allocatedRangeMemory.emplace_back(std::move(storage));
1381 rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1383 memory[memIndex] = &rangeMemory[rangeIndex];
1387 if constexpr (std::is_same_v<T, Type>) {
1388 return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1389 }
else if constexpr (std::is_same_v<T, Value>) {
1390 return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1392 llvm_unreachable(
"unhandled range type");
1397 const ByteCodeField *curCodeIt;
1406 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1408 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1423 void ByteCodeExecutor::executeApplyConstraint(
PatternRewriter &rewriter) {
1424 LDBG() <<
"Executing ApplyConstraint:";
1425 ByteCodeField fun_idx = read();
1427 readList<PDLValue>(args);
1429 LDBG() <<
" * Arguments: " << llvm::interleaved(args);
1431 ByteCodeField isNegated = read();
1432 LDBG() <<
" * isNegated: " << isNegated;
1434 ByteCodeField numResults = read();
1435 const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1436 ByteCodeRewriteResultList results(numResults);
1437 LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1439 if (succeeded(rewriteResult)) {
1440 LDBG() <<
" * Constraint succeeded, results: "
1441 << llvm::interleaved(constraintResults);
1443 LDBG() <<
" * Constraint failed";
1445 assert((
failed(rewriteResult) || constraintResults.size() == numResults) &&
1446 "native PDL rewrite function succeeded but returned "
1447 "unexpected number of results");
1448 processNativeFunResults(results, numResults, rewriteResult);
1451 selectJump(isNegated != succeeded(rewriteResult));
1454 LogicalResult ByteCodeExecutor::executeApplyRewrite(
PatternRewriter &rewriter) {
1455 LDBG() <<
"Executing ApplyRewrite:";
1456 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1458 readList<PDLValue>(args);
1460 LDBG() <<
" * Arguments: " << llvm::interleaved(args);
1463 ByteCodeField numResults = read();
1464 ByteCodeRewriteResultList results(numResults);
1465 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1467 assert(results.getResults().size() == numResults &&
1468 "native PDL rewrite function returned unexpected number of results");
1470 processNativeFunResults(results, numResults, rewriteResult);
1472 if (
failed(rewriteResult)) {
1473 LDBG() <<
" - Failed";
1479 void ByteCodeExecutor::processNativeFunResults(
1480 ByteCodeRewriteResultList &results,
unsigned numResults,
1481 LogicalResult &rewriteResult) {
1482 if (
failed(rewriteResult)) {
1485 for (
unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1487 if (resultKind == PDLValue::Kind::TypeRange ||
1488 resultKind == PDLValue::Kind::ValueRange) {
1498 for (
unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1501 PDLValue result = results.getResults()[resultIdx];
1502 LDBG() <<
" * Result: " << result;
1503 assert(result.getKind() == resultKind &&
1504 "native PDL rewrite function returned an unexpected type of "
1508 if (std::optional<TypeRange> typeRange = result.dyn_cast<
TypeRange>()) {
1509 unsigned rangeIndex = read();
1510 typeRangeMemory[rangeIndex] = *typeRange;
1511 memory[read()] = &typeRangeMemory[rangeIndex];
1512 }
else if (std::optional<ValueRange> valueRange =
1514 unsigned rangeIndex = read();
1515 valueRangeMemory[rangeIndex] = *valueRange;
1516 memory[read()] = &valueRangeMemory[rangeIndex];
1518 memory[read()] = result.getAsOpaquePointer();
1523 for (
auto &it : results.getAllocatedTypeRanges())
1524 allocatedTypeRangeMemory.push_back(std::move(it));
1525 for (
auto &it : results.getAllocatedValueRanges())
1526 allocatedValueRangeMemory.push_back(std::move(it));
1529 void ByteCodeExecutor::executeAreEqual() {
1530 LDBG() <<
"Executing AreEqual:";
1531 const void *lhs = read<const void *>();
1532 const void *rhs = read<const void *>();
1534 LDBG() <<
" * " << lhs <<
" == " << rhs;
1535 selectJump(lhs == rhs);
1538 void ByteCodeExecutor::executeAreRangesEqual() {
1539 LDBG() <<
"Executing AreRangesEqual:";
1541 const void *lhs = read<const void *>();
1542 const void *rhs = read<const void *>();
1544 switch (valueKind) {
1545 case PDLValue::Kind::TypeRange: {
1548 LDBG() <<
" * " << lhs <<
" == " << rhs;
1549 selectJump(*lhsRange == *rhsRange);
1552 case PDLValue::Kind::ValueRange: {
1553 const auto *lhsRange =
reinterpret_cast<const ValueRange *
>(lhs);
1554 const auto *rhsRange =
reinterpret_cast<const ValueRange *
>(rhs);
1555 LDBG() <<
" * " << lhs <<
" == " << rhs;
1556 selectJump(*lhsRange == *rhsRange);
1560 llvm_unreachable(
"unexpected `AreRangesEqual` value kind");
1564 void ByteCodeExecutor::executeBranch() {
1565 LDBG() <<
"Executing Branch";
1566 curCodeIt = &code[read<ByteCodeAddr>()];
1569 void ByteCodeExecutor::executeCheckOperandCount() {
1570 LDBG() <<
"Executing CheckOperandCount:";
1572 uint32_t expectedCount = read<uint32_t>();
1573 bool compareAtLeast = read();
1576 <<
"\n * Expected: " << expectedCount
1577 <<
"\n * Comparator: " << (compareAtLeast ?
">=" :
"==");
1584 void ByteCodeExecutor::executeCheckOperationName() {
1585 LDBG() <<
"Executing CheckOperationName:";
1589 LDBG() <<
" * Found: \"" << op->
getName() <<
"\"\n * Expected: \""
1590 << expectedName <<
"\"";
1591 selectJump(op->
getName() == expectedName);
1594 void ByteCodeExecutor::executeCheckResultCount() {
1595 LDBG() <<
"Executing CheckResultCount:";
1597 uint32_t expectedCount = read<uint32_t>();
1598 bool compareAtLeast = read();
1601 <<
"\n * Expected: " << expectedCount
1602 <<
"\n * Comparator: " << (compareAtLeast ?
">=" :
"==");
1609 void ByteCodeExecutor::executeCheckTypes() {
1610 LDBG() <<
"Executing AreEqual:";
1613 LDBG() <<
" * " << lhs <<
" == " << rhs;
1615 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1618 void ByteCodeExecutor::executeContinue() {
1619 ByteCodeField level = read();
1620 LDBG() <<
"Executing Continue\n * Level: " << level;
1625 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1626 LDBG() <<
"Executing CreateConstantTypeRange:";
1627 unsigned memIndex = read();
1628 unsigned rangeIndex = read();
1629 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1631 LDBG() <<
" * Types: " << typesAttr;
1632 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1636 void ByteCodeExecutor::executeCreateOperation(
PatternRewriter &rewriter,
1638 LDBG() <<
"Executing CreateOperation:";
1640 unsigned memIndex = read();
1642 readList(state.operands);
1643 for (
unsigned i = 0, e = read(); i != e; ++i) {
1644 StringAttr name = read<StringAttr>();
1646 state.addAttribute(name, attr);
1651 unsigned numResults = read();
1653 InferTypeOpInterface::Concept *inferInterface =
1654 state.name.getInterface<InferTypeOpInterface>();
1655 assert(inferInterface &&
1656 "expected operation to provide InferTypeOpInterface");
1659 if (
failed(inferInterface->inferReturnTypes(
1660 state.getContext(), state.location, state.operands,
1661 state.attributes.getDictionary(state.getContext()),
1662 state.getRawProperties(), state.regions, state.types)))
1666 for (
unsigned i = 0; i != numResults; ++i) {
1667 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1668 state.types.push_back(read<Type>());
1670 TypeRange *resultTypes = read<TypeRange *>();
1671 state.types.append(resultTypes->begin(), resultTypes->end());
1677 memory[memIndex] = resultOp;
1679 LDBG() <<
" * Attributes: "
1680 << state.attributes.getDictionary(state.getContext())
1681 <<
"\n * Operands: " << llvm::interleaved(state.operands)
1682 <<
"\n * Result Types: " << llvm::interleaved(state.types)
1683 <<
"\n * Result: " << *resultOp;
1686 template <
typename T>
1687 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1688 LDBG() <<
"Executing CreateDynamic" << type <<
"Range:";
1689 unsigned memIndex = read();
1690 unsigned rangeIndex = read();
1694 LDBG() <<
" * " << type <<
"s: " << llvm::interleaved(values);
1696 assignRangeToMemory(values, memIndex, rangeIndex);
1700 LDBG() <<
"Executing EraseOp:";
1703 LDBG() <<
" * Operation: " << *op;
1707 template <
typename T,
typename Range, PDLValue::Kind kind>
1708 void ByteCodeExecutor::executeExtract() {
1709 LDBG() <<
"Executing Extract" <<
kind <<
":";
1710 Range *range = read<Range *>();
1711 unsigned index = read<uint32_t>();
1712 unsigned memIndex = read();
1715 memory[memIndex] =
nullptr;
1719 T result = index < range->
size() ? (*range)[index] : T();
1720 LDBG() <<
" * " <<
kind <<
"s(" << range->
size() <<
")";
1721 LDBG() <<
" * Index: " << index;
1722 LDBG() <<
" * Result: " << result;
1723 storeToMemory(memIndex, result);
/// Emits the "Executing Finalize" debug message; this handler performs no other
/// work itself.
1726 void ByteCodeExecutor::executeFinalize() { LDBG() <<
"Executing Finalize"; }
1728 void ByteCodeExecutor::executeForEach() {
1729 LDBG() <<
"Executing ForEach:";
1730 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1731 unsigned rangeIndex = read();
1732 unsigned memIndex = read();
1733 const void *value =
nullptr;
1735 switch (read<PDLValue::Kind>()) {
1736 case PDLValue::Kind::Operation: {
1737 unsigned &index = loopIndex[read()];
1739 assert(index <= array.size() &&
"iterated past the end");
1740 if (index < array.size()) {
1741 LDBG() <<
" * Result: " << array[index];
1742 value = array[index];
1746 LDBG() <<
" * Done";
1748 selectJump(
size_t(0));
1752 llvm_unreachable(
"unexpected `ForEach` value kind");
1756 memory[memIndex] = value;
1757 pushCodeIt(prevCodeIt);
1760 read<ByteCodeAddr>();
1763 void ByteCodeExecutor::executeGetAttribute() {
1764 LDBG() <<
"Executing GetAttribute:";
1765 unsigned memIndex = read();
1767 StringAttr attrName = read<StringAttr>();
1770 LDBG() <<
" * Operation: " << *op <<
"\n * Attribute: " << attrName
1771 <<
"\n * Result: " << attr;
1775 void ByteCodeExecutor::executeGetAttributeType() {
1776 LDBG() <<
"Executing GetAttributeType:";
1777 unsigned memIndex = read();
1780 if (
auto typedAttr = dyn_cast<TypedAttr>(attr))
1781 type = typedAttr.getType();
1783 LDBG() <<
" * Attribute: " << attr <<
"\n * Result: " << type;
1787 void ByteCodeExecutor::executeGetDefiningOp() {
1788 LDBG() <<
"Executing GetDefiningOp:";
1789 unsigned memIndex = read();
1791 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1792 Value value = read<Value>();
1795 LDBG() <<
" * Value: " << value;
1798 if (values && !values->empty()) {
1799 op = values->front().getDefiningOp();
1801 LDBG() <<
" * Values: " << values;
1804 LDBG() <<
" * Result: " << op;
1805 memory[memIndex] = op;
1808 void ByteCodeExecutor::executeGetOperand(
unsigned index) {
1810 unsigned memIndex = read();
1814 LDBG() <<
" * Operation: " << *op <<
"\n * Index: " << index
1815 <<
"\n * Result: " << operand;
1822 template <
template <
typename>
class AttrSizedSegmentsT,
typename RangeT>
1825 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1830 LDBG() <<
" * Getting all values";
1834 }
else if (op->
hasTrait<AttrSizedSegmentsT>()) {
1835 LDBG() <<
" * Extracting values from `" << attrSizedSegments <<
"`";
1838 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1842 unsigned startIndex =
1843 std::accumulate(segments.begin(), segments.begin() + index, 0);
1844 values = values.slice(startIndex, *std::next(segments.begin(), index));
1846 LDBG() <<
" * Extracting range[" << startIndex <<
", "
1847 << *std::next(segments.begin(), index) <<
"]";
1853 }
else if (values.size() >= index) {
1854 LDBG() <<
" * Treating values as trailing variadic range";
1855 values = values.drop_front(index);
1864 valueRangeMemory[rangeIndex] = values;
1865 return &valueRangeMemory[rangeIndex];
1869 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1872 void ByteCodeExecutor::executeGetOperands() {
1873 LDBG() <<
"Executing GetOperands:";
1874 unsigned index = read<uint32_t>();
1876 ByteCodeField rangeIndex = read();
1878 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1879 op->
getOperands(), op, index, rangeIndex,
"operandSegmentSizes",
1882 LDBG() <<
" * Invalid operand range";
1883 memory[read()] = result;
1886 void ByteCodeExecutor::executeGetResult(
unsigned index) {
1888 unsigned memIndex = read();
1892 LDBG() <<
" * Operation: " << *op <<
"\n * Index: " << index
1893 <<
"\n * Result: " << result;
1897 void ByteCodeExecutor::executeGetResults() {
1898 LDBG() <<
"Executing GetResults:";
1899 unsigned index = read<uint32_t>();
1901 ByteCodeField rangeIndex = read();
1903 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1904 op->
getResults(), op, index, rangeIndex,
"resultSegmentSizes",
1907 LDBG() <<
" * Invalid result range";
1908 memory[read()] = result;
1911 void ByteCodeExecutor::executeGetUsers() {
1912 LDBG() <<
"Executing GetUsers:";
1913 unsigned memIndex = read();
1914 unsigned rangeIndex = read();
1915 OwningOpRange &range = opRangeMemory[rangeIndex];
1916 memory[memIndex] = &range;
1918 range = OwningOpRange();
1919 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1921 Value value = read<Value>();
1924 LDBG() <<
" * Value: " << value;
1934 LDBG() <<
" * Values (" << values->size()
1935 <<
"): " << llvm::interleaved(*values);
1939 for (
Value value : *values)
1941 range = OwningOpRange(users.size());
1945 LDBG() <<
" * Result: " << range.size() <<
" operations";
1948 void ByteCodeExecutor::executeGetValueType() {
1949 LDBG() <<
"Executing GetValueType:";
1950 unsigned memIndex = read();
1951 Value value = read<Value>();
1954 LDBG() <<
" * Value: " << value <<
"\n * Result: " << type;
1958 void ByteCodeExecutor::executeGetValueRangeTypes() {
1959 LDBG() <<
"Executing GetValueRangeTypes:";
1960 unsigned memIndex = read();
1961 unsigned rangeIndex = read();
1964 LDBG() <<
" * Values: <NULL>";
1965 memory[memIndex] =
nullptr;
1969 LDBG() <<
" * Values (" << values->size()
1970 <<
"): " << llvm::interleaved(*values)
1971 <<
"\n * Result: " << llvm::interleaved(values->
getType());
1972 typeRangeMemory[rangeIndex] = values->
getType();
1973 memory[memIndex] = &typeRangeMemory[rangeIndex];
1976 void ByteCodeExecutor::executeIsNotNull() {
1977 LDBG() <<
"Executing IsNotNull:";
1978 const void *value = read<const void *>();
1980 LDBG() <<
" * Value: " << value;
1981 selectJump(value !=
nullptr);
1984 void ByteCodeExecutor::executeRecordMatch(
1987 LDBG() <<
"Executing RecordMatch:";
1988 unsigned patternIndex = read();
1990 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1995 LDBG() <<
" * Benefit: Impossible To Match";
2004 unsigned numMatchLocs = read();
2006 matchLocs.reserve(numMatchLocs);
2007 for (
unsigned i = 0; i != numMatchLocs; ++i)
2008 matchLocs.push_back(read<Operation *>()->getLoc());
2011 LDBG() <<
" * Benefit: " << benefit.
getBenefit();
2012 LDBG() <<
" * Location: " << matchLoc;
2013 matches.emplace_back(matchLoc,
patterns[patternIndex], benefit);
2019 unsigned numInputs = read();
2020 match.values.reserve(numInputs);
2021 match.typeRangeValues.reserve(numInputs);
2022 match.valueRangeValues.reserve(numInputs);
2023 for (
unsigned i = 0; i < numInputs; ++i) {
2024 switch (read<PDLValue::Kind>()) {
2025 case PDLValue::Kind::TypeRange:
2026 match.typeRangeValues.push_back(*read<TypeRange *>());
2027 match.values.push_back(&match.typeRangeValues.back());
2029 case PDLValue::Kind::ValueRange:
2030 match.valueRangeValues.push_back(*read<ValueRange *>());
2031 match.values.push_back(&match.valueRangeValues.back());
2034 match.values.push_back(read<const void *>());
2042 LDBG() <<
"Executing ReplaceOp:";
2047 LDBG() <<
" * Operation: " << *op
2048 <<
"\n * Values: " << llvm::interleaved(args);
2052 void ByteCodeExecutor::executeSwitchAttribute() {
2053 LDBG() <<
"Executing SwitchAttribute:";
2055 ArrayAttr cases = read<ArrayAttr>();
2056 handleSwitch(value, cases);
2059 void ByteCodeExecutor::executeSwitchOperandCount() {
2060 LDBG() <<
"Executing SwitchOperandCount:";
2062 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2064 LDBG() <<
" * Operation: " << *op;
2068 void ByteCodeExecutor::executeSwitchOperationName() {
2069 LDBG() <<
"Executing SwitchOperationName:";
2071 size_t caseCount = read();
2077 const ByteCodeField *prevCodeIt = curCodeIt;
2078 LDBG() <<
" * Value: " << value <<
"\n * Cases: "
2079 << llvm::interleaved(
2080 llvm::map_range(llvm::seq<size_t>(0, caseCount), [&](
size_t) {
2081 return read<OperationName>();
2083 curCodeIt = prevCodeIt;
2087 for (
size_t i = 0; i != caseCount; ++i) {
2088 if (read<OperationName>() == value) {
2089 curCodeIt += (caseCount - i - 1);
2090 return selectJump(i + 1);
2093 selectJump(
size_t(0));
2096 void ByteCodeExecutor::executeSwitchResultCount() {
2097 LDBG() <<
"Executing SwitchResultCount:";
2099 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2101 LDBG() <<
" * Operation: " << *op;
2105 void ByteCodeExecutor::executeSwitchType() {
2106 LDBG() <<
"Executing SwitchType:";
2107 Type value = read<Type>();
2108 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2109 handleSwitch(value, cases);
2112 void ByteCodeExecutor::executeSwitchTypes() {
2113 LDBG() <<
"Executing SwitchTypes:";
2115 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2117 LDBG() <<
"Types: <NULL>";
2118 return selectJump(
size_t(0));
2120 handleSwitch(*value, cases, [](ArrayAttr caseValue,
const TypeRange &value) {
2121 return value == caseValue.getAsValueRange<TypeAttr>();
2128 std::optional<Location> mainRewriteLoc) {
2131 LDBG() << readInline<Location>();
2133 OpCode opCode =
static_cast<OpCode
>(read());
2135 case ApplyConstraint:
2136 executeApplyConstraint(rewriter);
2139 if (
failed(executeApplyRewrite(rewriter)))
2145 case AreRangesEqual:
2146 executeAreRangesEqual();
2151 case CheckOperandCount:
2152 executeCheckOperandCount();
2154 case CheckOperationName:
2155 executeCheckOperationName();
2157 case CheckResultCount:
2158 executeCheckResultCount();
2161 executeCheckTypes();
2166 case CreateConstantTypeRange:
2167 executeCreateConstantTypeRange();
2169 case CreateOperation:
2170 executeCreateOperation(rewriter, *mainRewriteLoc);
2172 case CreateDynamicTypeRange:
2173 executeDynamicCreateRange<Type>(
"Type");
2175 case CreateDynamicValueRange:
2176 executeDynamicCreateRange<Value>(
"Value");
2179 executeEraseOp(rewriter);
2182 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2185 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2188 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2198 executeGetAttribute();
2200 case GetAttributeType:
2201 executeGetAttributeType();
2204 executeGetDefiningOp();
2210 unsigned index = opCode - GetOperand0;
2211 LDBG() <<
"Executing GetOperand" << index <<
":";
2212 executeGetOperand(index);
2216 LDBG() <<
"Executing GetOperandN:";
2217 executeGetOperand(read<uint32_t>());
2220 executeGetOperands();
2226 unsigned index = opCode - GetResult0;
2227 LDBG() <<
"Executing GetResult" << index <<
":";
2228 executeGetResult(index);
2232 LDBG() <<
"Executing GetResultN:";
2233 executeGetResult(read<uint32_t>());
2236 executeGetResults();
2242 executeGetValueType();
2244 case GetValueRangeTypes:
2245 executeGetValueRangeTypes();
2252 "expected matches to be provided when executing the matcher");
2253 executeRecordMatch(rewriter, *matches);
2256 executeReplaceOp(rewriter);
2258 case SwitchAttribute:
2259 executeSwitchAttribute();
2261 case SwitchOperandCount:
2262 executeSwitchOperandCount();
2264 case SwitchOperationName:
2265 executeSwitchOperationName();
2267 case SwitchResultCount:
2268 executeSwitchResultCount();
2271 executeSwitchType();
2274 executeSwitchTypes();
2285 state.memory[0] = op;
2288 ByteCodeExecutor executor(
2289 matcherByteCode.data(), state.memory, state.opRangeMemory,
2290 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2291 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2292 uniquedData, matcherByteCode, state.currentPatternBenefits,
patterns,
2293 constraintFunctions, rewriteFunctions);
2294 LogicalResult executeResult = executor.execute(rewriter, &matches);
2295 (void)executeResult;
2296 assert(succeeded(executeResult) &&
"unexpected matcher execution failure");
2299 llvm::stable_sort(matches,
2300 [](
const MatchResult &lhs,
const MatchResult &rhs) {
2301 return lhs.benefit > rhs.benefit;
2306 const MatchResult &match,
2308 auto *configSet =
match.pattern->getConfigSet();
2310 configSet->notifyRewriteBegin(rewriter);
2316 ByteCodeExecutor executor(
2317 &rewriterByteCode[
match.pattern->getRewriterAddr()], state.memory,
2318 state.opRangeMemory, state.typeRangeMemory,
2319 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2320 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2321 rewriterByteCode, state.currentPatternBenefits,
patterns,
2322 constraintFunctions, rewriteFunctions);
2323 LogicalResult result =
2324 executor.execute(rewriter,
nullptr,
match.location);
2327 configSet->notifyRewriteEnd(rewriter);
2336 LDBG() <<
" and rollback is not supported - aborting";
2337 llvm::report_fatal_error(
2338 "Native PDL Rewrite failed, but the pattern "
2339 "rewriter doesn't support recovery. Failable pattern rewrites should "
2340 "not be used with pattern rewriters that do not support them.");
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
This class represents liveness information on block level.
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Represents an analysis for computing liveness information from a given top-level operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class implements the result iterators for the Operation class.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
auto walk(WalkFns &&...walkFns)
Walk this type and all attributes/types nested within using the provided walk functions.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_iterator user_begin() const
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
user_iterator user_end() const
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...