16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/SmallVector.h"
57 : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
58 using OpRewritePatternWithSymbolTableCollection::
59 OpRewritePatternWithSymbolTableCollection;
60 LogicalResult matchAndRewrite(GridShapeOp op,
63 GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
64 op.getOperation(), op.getGridAttr());
68 ArrayRef<GridAxis> opGridAxes = op.getAxes();
70 if (opGridAxes.empty()) {
71 opAxesIota.resize(grid.getRank());
72 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
73 opGridAxes = opAxesIota;
75 if (llvm::all_of(opGridAxes, [&grid](
GridAxis axis) {
76 return ShapedType::isDynamic(grid.getShape()[axis]);
82 SmallVector<Value> newResults(op->getResults().size());
83 SmallVector<GridAxis> newShapeOpGridAxes;
84 SmallVector<size_t> newToOldResultsIndexMap;
86 for (
size_t i = 0; i < opGridAxes.size(); ++i) {
87 auto gridAxisSize = grid.getShape()[opGridAxes[i]];
88 if (ShapedType::isDynamic(gridAxisSize)) {
89 newToOldResultsIndexMap.push_back(i);
90 newShapeOpGridAxes.push_back(opGridAxes[i]);
93 newResults[i] = arith::ConstantOp::create(
94 builder, builder.getIndexAttr(gridAxisSize));
99 if (!newShapeOpGridAxes.empty()) {
100 GridShapeOp newShapeOp =
101 GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
102 for (
size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
103 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
116 patterns.add<GridShapeFolder>(symbolTableCollection,
patterns.getContext());
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents a collection of SymbolTables.
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateAllReduceEndomorphismSimplificationPatterns(RewritePatternSet &patterns, ReductionKind reduction)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns