22 #define GEN_PASS_DEF_SPIRVREWRITEINSERTSPASS
23 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
33 class RewriteInsertsPass
34 :
public spirv::impl::SPIRVRewriteInsertsPassBase<RewriteInsertsPass> {
36 void runOnOperation()
override;
43 collectInsertionChain(spirv::CompositeInsertOp op,
49 void RewriteInsertsPass::runOnOperation() {
51 getOperation().walk([
this, &workList](spirv::CompositeInsertOp op) {
53 if (succeeded(collectInsertionChain(op, insertions)))
54 workList.push_back(insertions);
57 for (
const auto &insertions : workList) {
58 auto lastCompositeInsertOp = insertions.back();
59 auto compositeType = lastCompositeInsertOp.getType();
60 auto location = lastCompositeInsertOp.getLoc();
64 for (
auto insertionOp : insertions)
65 operands.push_back(insertionOp.getObject());
68 auto compositeConstructOp = spirv::CompositeConstructOp::create(
69 builder, location, compositeType, operands);
71 lastCompositeInsertOp.replaceAllUsesWith(
72 compositeConstructOp->getResult(0));
75 for (
auto insertOp : llvm::reverse(insertions)) {
76 auto *op = insertOp.getOperation();
83 LogicalResult RewriteInsertsPass::collectInsertionChain(
84 spirv::CompositeInsertOp op,
86 if (isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
89 auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
91 if (indicesArrayAttr.size() == 1) {
92 auto numElements = cast<spirv::CompositeType>(op.getComposite().getType())
95 auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt();
97 if (index + 1 != numElements)
100 insertions.resize(numElements);
102 insertions[index] = op;
107 op = op.getComposite().getDefiningOp<spirv::CompositeInsertOp>();
112 indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
113 if ((indicesArrayAttr.size() != 1) ||
114 (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index))
This class helps build Operations.
Include the generated interface declarations.