23 #define GEN_PASS_DEF_SPIRVREWRITEINSERTSPASS
24 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
34 class RewriteInsertsPass
35 :
public spirv::impl::SPIRVRewriteInsertsPassBase<RewriteInsertsPass> {
37 void runOnOperation()
override;
44 collectInsertionChain(spirv::CompositeInsertOp op,
50 void RewriteInsertsPass::runOnOperation() {
52 getOperation().walk([
this, &workList](spirv::CompositeInsertOp op) {
54 if (
succeeded(collectInsertionChain(op, insertions)))
55 workList.push_back(insertions);
58 for (
const auto &insertions : workList) {
59 auto lastCompositeInsertOp = insertions.back();
60 auto compositeType = lastCompositeInsertOp.getType();
61 auto location = lastCompositeInsertOp.getLoc();
65 for (
auto insertionOp : insertions)
66 operands.push_back(insertionOp.getObject());
69 auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
70 location, compositeType, operands);
72 lastCompositeInsertOp.replaceAllUsesWith(
73 compositeConstructOp->getResult(0));
76 for (
auto insertOp : llvm::reverse(insertions)) {
77 auto *op = insertOp.getOperation();
85 spirv::CompositeInsertOp op,
87 auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
89 if (indicesArrayAttr.size() == 1) {
90 auto numElements = cast<spirv::CompositeType>(op.getComposite().getType())
93 auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt();
95 if (index + 1 != numElements)
98 insertions.resize(numElements);
100 insertions[index] = op;
105 op = op.getComposite().getDefiningOp<spirv::CompositeInsertOp>();
110 indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
111 if ((indicesArrayAttr.size() != 1) ||
112 (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index))
This class helps build Operations.
bool use_empty()
Returns true if this operation has no uses.
void erase()
Remove this operation from its parent block and delete it.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.