48 return enc && !llvm::all_of(enc.getLvlTypes(),
49 [](
auto lt) { return lt == LevelType::Dense; });
55 Value val = op->get();
73 if (
auto *def = yieldOp.getOperand(0).getDefiningOp()) {
74 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
78 return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
79 (def->getOperand(1) == s1 && def->getOperand(0) == s2);
87 if (
auto arg = dyn_cast<BlockArgument>(val))
90 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
100 if (
auto *def = yieldOp.getOperand(0).getDefiningOp()) {
101 if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
103 return (def->getOperand(0) == x &&
isMulChain(def->getOperand(1), x)) ||
104 (def->getOperand(1) == x &&
isMulChain(def->getOperand(0), x));
113 if (
auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
114 if (arg.getOwner()->getParentOp() == op) {
125 for (
const auto &d :
enumerate(stp.getShape())) {
127 if (d.value() == ShapedType::kDynamic)
128 dim = builder.
create<tensor::DimOp>(loc, tensor, d.index());
131 sizes.push_back(dim);
146 for (
const auto &d :
enumerate(tp.getShape())) {
147 if (d.value() == ShapedType::kDynamic)
148 dynSizes.push_back(sizes[d.index()]);
154 SparseElementsAttr attr) {
160 rewriter, loc, attr, op.getOrder().value_or(
AffineMap()),
163 args.append(cvs.begin(), cvs.end());
167 auto cloned = cast<ForeachOp>(rewriter.
clone(*op.getOperation()));
168 assert(args.size() == cloned.getBody()->getNumArguments());
169 Operation *yield = cloned.getBody()->getTerminator();
173 reduc = yield->getOperands();
187 auto dstShape = dstTp.getShape();
191 if (dstShape[dim] != ShapedType::kDynamic) {
196 for (
const auto &src : srcs.drop_front()) {
199 sizes[dim] = builder.
create<arith::AddIOp>(loc, sizes[dim], srcSz);
225 rewriter.
replaceOp(op, op.getDpsInitOperand(0)->get());
229 if (!outputType.hasStaticShape())
231 Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
260 if (!op.hasTensorSemantics() || op.getNumDpsInputs() != 2 ||
262 op.getNumParallelLoops() != op.getNumLoops() ||
263 !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
264 !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
265 !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
277 auto prod = dyn_cast_or_null<GenericOp>(
278 op.getDpsInputOperand(other)->get().getDefiningOp());
279 if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
280 !prod.getResult(0).hasOneUse())
292 inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
293 fusedIndexMaps.push_back(fusedIndexMaps.back());
295 auto fusedOp = rewriter.
create<GenericOp>(
303 fusedOp.getRegion().push_back(fusedBlock);
304 unsigned num = prodBlock.getNumArguments();
305 for (
unsigned i = 0; i < num - 1; i++)
306 addArg(mapper, fusedBlock, prodBlock.getArgument(i));
307 addArg(mapper, fusedBlock, consBlock.
getArgument(1 - other));
308 addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
310 auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
314 for (
auto &op : prodBlock.without_terminator())
317 rewriter.
clone(op, mapper);
322 rewriter.
create<linalg::YieldOp>(loc, last);
326 Value init = prod.getDpsInitOperand(0)
331 op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
336 rewriter.
replaceOp(op, fusedOp->getResults());
358 Type srcType = op.getSource().getType();
359 Type dstType = op.getDest().getType();
361 if (srcType == dstType) {
367 if (
Operation *def = op.getSource().getDefiningOp()) {
368 if (def->
hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
417 auto matched = isRewritablePattern(op, &inst);
418 if (!matched.has_value())
422 auto [c, t, f] = matched.value();
423 assert(t.getType() == f.getType());
424 auto selTp = t.getType();
426 auto binOp = rewriter.
create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
428 rewriter.
createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
429 {t.getLoc(), f.getLoc()});
430 rewriter.
createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
431 rewriter.
createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
433 for (
auto *r : binOp.getRegions()) {
441 if (
auto *def = c.getDefiningOp())
445 if (r == &binOp.getLeftRegion()) {
448 }
else if (r == &binOp.getRightRegion()) {
456 rewriter.
create<sparse_tensor::YieldOp>(loc, y);
462 semiRings.emplace_back(&inst, binOp);
466 for (
auto [sel, semi] : semiRings)
467 rewriter.
replaceOp(sel, semi->getResults());
469 return success(!semiRings.empty());
473 static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
474 isRewritablePattern(GenericOp op,
Operation *v) {
475 auto sel = dyn_cast<arith::SelectOp>(v);
489 auto isValFromDenseInputOrInvariant = [&op](
Value v) ->
bool {
491 bArg && !
isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
494 return v.getDefiningOp() && v.getDefiningOp()->
getBlock() != op.getBody();
499 auto cond = sel.getCondition();
500 if (isValFromDenseInputOrInvariant(cond))
501 return std::make_tuple(cond, tVal, fVal);
510 if (isValFromDenseInputOrInvariant(cmpL) ||
511 isValFromDenseInputOrInvariant(cmpR))
512 return std::make_tuple(cond, tVal, fVal);
543 if (!op.hasTensorSemantics() || op.getNumDpsInputs() != 1 ||
546 auto inp = op.getDpsInputOperand(0);
547 auto init = op.getDpsInitOperand(0);
554 if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
555 arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
556 arith::MaxUIOp>(red))
560 if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
561 (red->getOperand(0) != s1 || red->getOperand(1) != s0))
573 auto semiring = rewriter.
create<sparse_tensor::UnaryOp>(loc, rtp, s0);
575 rewriter.
createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
577 rewriter.
create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
578 rewriter.
createBlock(&semiring.getAbsentRegion(), {}, {}, {});
582 rewriter.
create<sparse_tensor::YieldOp>(loc, zero);
587 auto custom = rewriter.
create<sparse_tensor::ReduceOp>(
588 loc, rtp, semiring.getResult(), s1, identity);
590 rewriter.
createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
593 irMap.
map(red->getOperand(0), region->getArgument(0));
594 irMap.
map(red->getOperand(1), region->getArgument(1));
595 auto cloned = rewriter.
clone(*red, irMap);
596 rewriter.
create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
598 rewriter.
replaceOp(red, custom.getResult());
611 Value srcTensor = op.getSource();
615 if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
616 !dstTp.hasStaticDimShape())
625 Value nnz = rewriter.
create<NumberOfEntriesOp>(loc, srcTensor);
629 dstTp.withoutDimToLvl(),
630 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
632 Value buffer = rewriter
633 .
create<AllocTensorOp>(loc, bufferTp, dynSizes,
Value(),
648 const auto encSrc = srcTp.getEncoding();
649 ForeachOp foreachOp = rewriter.
create<ForeachOp>(
650 loc, srcTensor, buffer,
653 const Dimension srcRank = srcTp.getDimRank();
655 srcDcvs.reserve(srcRank);
656 for (
Dimension d = 0; d < srcRank; d++) {
658 srcDcvs.push_back(srcLcvs[lvl]);
664 builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
669 collapseIdx.push_back(i);
672 reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
673 collapsedSizes, collapsedDcvs);
676 for (
Dimension i = 0; i < dstTp.getDimRank(); i++)
677 expandIdx.push_back(i);
680 reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
683 auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
684 builder.create<sparse_tensor::YieldOp>(loc, t);
687 Value t = rewriter.
create<LoadOp>(loc, foreachOp.getResult(0),
true);
688 if (bufferTp != dstTp) {
689 auto dstRTT = dstTp.getRankedTensorType();
690 Value converted = rewriter.
create<ConvertOp>(loc, dstRTT, t).getResult();
691 rewriter.
create<DeallocTensorOp>(loc, t);
700 template <
typename ReshapeOp>
708 Value srcTensor = op.getSrc();
711 if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
720 if (dstTp.hasStaticDimShape()) {
726 op.getReassociationIndices());
728 if (shape == ShapedType::kDynamic)
729 dstDynSizes.push_back(dstSizes[idx]);
732 Value nnz = rewriter.
create<NumberOfEntriesOp>(loc, srcTensor);
736 dstTp.withoutDimToLvl(),
737 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
741 .
create<AllocTensorOp>(loc, bufferTp, dstDynSizes,
Value(),
752 const auto encSrc = srcTp.getEncoding();
753 ForeachOp foreachOp = rewriter.
create<ForeachOp>(
754 loc, srcTensor, buffer,
757 const Dimension dimRank = srcTp.getDimRank();
759 srcDcvs.reserve(dimRank);
760 for (
Dimension d = 0; d < dimRank; d++) {
762 srcDcvs.push_back(srcLcvs[lvl]);
765 reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
766 srcDcvs, dstSizes, dstDcvs);
767 auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
768 builder.create<sparse_tensor::YieldOp>(loc, t);
771 Value t = rewriter.
create<LoadOp>(loc, foreachOp.getResult(0),
true);
772 if (bufferTp != dstTp) {
773 auto dstRTT = dstTp.getRankedTensorType();
774 Value converted = rewriter.
create<ConvertOp>(loc, dstRTT, t).getResult();
775 rewriter.
create<DeallocTensorOp>(loc, t);
785 template <
typename ReshapeOp>
799 if (encDst && encSrc) {
806 auto convert = rewriter.
create<ConvertOp>(loc, denseTp, op.getSrc());
814 auto reshape = rewriter.
create<ReshapeOp>(loc, denseTp, op.getSrc(),
815 op.getReassociation());
816 Value convert = rewriter.
create<ConvertOp>(loc, rtp, reshape);
832 val = builder.
create<AllocTensorOp>(loc, rtt, dynSzs);
835 val = builder.
create<linalg::FillOp>(loc, c0, val).getResult(0);
840 val = builder.
create<tensor::InsertOp>(loc, v, val, crds);
845 return builder.
create<LoadOp>(loc, val,
true);
849 bool isSparse()
const {
860 std::optional<int64_t> dim = op.getConstantIndex();
862 if (!dim || !stt.hasEncoding())
865 if (stt.isPermutation()) {
867 toLvl(stt.getEncoding(), *dim));
879 for (Level l = 0; l < stt.getLvlRank(); l++) {
880 Value lvlSz = rewriter.
create<LvlOp>(loc, op.getSource(), l);
883 maxLvlCrds.push_back(maxLvlCrd);
886 AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
887 Value maxDimCrd = rewriter.
create<affine::AffineApplyOp>(
902 if (op.needsExtraSort())
903 op.
emitError(
"ConcatenateOp not staged");
907 const Dimension conDim = op.getDimension();
924 TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
926 Value iterArg = dstBuf.val;
929 for (
Value input : op.getInputs()) {
932 foreachOp = rewriter.
create<ForeachOp>(
938 builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
941 dstBuf.val = reduc.front();
942 if (!dstTp.isAllDense()) {
944 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
946 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
947 builder.create<scf::YieldOp>(loc, dstBuf.val);
949 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
950 dstBuf.insert(builder, loc, v, offDimCrd);
951 builder.create<scf::YieldOp>(loc, dstBuf.val);
954 builder.setInsertionPointAfter(ifOp);
955 dstBuf.val = ifOp.getResult(0);
957 dstBuf.insert(builder, loc, v, offDimCrd);
959 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
965 assert(!ShapedType::isDynamic(sz));
966 offset = rewriter.
create<arith::AddIOp>(loc, offset,
969 dstBuf.val = iterArg;
972 dstBuf.val = iterArg;
973 Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
983 if (op.needsExtraSort())
984 return op.
emitError(
"ConvertOp not staged.");
989 if (encDst && encSrc && !encSrc.isSlice() &&
990 encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
997 Value src = op.getSource();
1002 bool fromSparseConst =
false;
1003 if (
auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1004 if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1005 fromSparseConst =
true;
1007 const AffineMapAttr foreachOrder =
1012 bool skipZeroCheck = srcStt.
hasEncoding() || fromSparseConst;
1019 auto foreachOp = rewriter.
create<ForeachOp>(
1020 loc, src, dstBuf.val, foreachOrder,
1024 dstBuf.val = reduc.front();
1025 if (!skipZeroCheck) {
1027 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1029 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1030 builder.create<scf::YieldOp>(loc, dstBuf.val);
1032 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1033 dstBuf.insert(builder, loc, v, dcvs);
1034 builder.create<scf::YieldOp>(loc, dstBuf.val);
1037 builder.setInsertionPointAfter(ifOp);
1038 dstBuf.val = ifOp.getResult(0);
1040 dstBuf.insert(builder, loc, v, dcvs);
1042 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1048 dstBuf.val = foreachOp.getResult(0);
1060 AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1061 ? op.getEncoder().getDimToLvl()
1062 : op.getEncoder().getLvlToDim();
1069 Value trans = rewriter.
create<affine::AffineApplyOp>(
1072 outCrds.push_back(trans);
1088 Value input = op.getTensor();
1091 const Level lvlRank = stt.getLvlRank();
1095 if (
auto constOp = input.
getDefiningOp<arith::ConstantOp>()) {
1096 if (
auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1102 const auto enc = stt.getEncoding();
1109 for (Level l = 0; l < lvlRank; l++) {
1113 loopEmitter.makeTensorLevel(0, l)};
1114 loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1117 loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
1122 if (op.getOrder()) {
1125 "Level order not yet implemented on non-constant input tensors.");
1128 Value vals = loopEmitter.getValBuffer()[0];
1129 Value pos = loopEmitter.getPosits()[0].back();
1132 Value val = enc ? rewriter.
create<memref::LoadOp>(loc, vals, pos)
1133 : rewriter.
create<memref::LoadOp>(loc, vals, lcvs);
1136 Block *srcBlock = op.getBody();
1140 enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1143 args.push_back(val);
1152 if (!reducValue.empty()) {
1162 for (Level l = 0; l < lvlRank; l++) {
1165 loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1166 loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1183 if (!stt.hasEncoding() || stt.getCOOStart() == 0)
1190 RankedTensorType dstTp = stt.getRankedTensorType();
1191 RankedTensorType cooTp = stt.getCOOType(
true);
1193 Value convert = cooTensor;
1194 auto enc = stt.getEncoding();
1195 if (!stt.isPermutation()) {
1197 convert = rewriter.
create<ReinterpretMapOp>(loc, coo, convert);
1200 convert = rewriter.
create<ConvertOp>(loc, dstTp, convert);
1201 if (!stt.isPermutation())
1202 convert = rewriter.
create<ReinterpretMapOp>(loc, enc, convert);
1207 rewriter.
create<DeallocTensorOp>(loc, cooTensor);
1220 Value src = op.getTensor();
1221 Value nnz = rewriter.
create<NumberOfEntriesOp>(loc, src);
1225 const Dimension dimRank = srcTp.getDimRank();
1233 for (
Dimension d = 0; d < dimRank; d++) {
1234 rewriter.
create<memref::StoreOp>(loc, dims[d], dimSizes,
1241 createFuncCall(rewriter, loc,
"createSparseTensorWriter", {opaqueTp},
1242 {op.getDest()}, EmitCInterface::Off)
1245 createFuncCall(rewriter, loc,
"outSparseTensorWriterMetaData", {},
1246 {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1248 Value dimCoords = dimSizes;
1249 Type eltTp = srcTp.getElementType();
1256 rewriter.
create<ForeachOp>(
1257 loc, src, std::nullopt,
1260 for (
Dimension d = 0; d < dimRank; d++) {
1261 rewriter.
create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1264 rewriter.
create<memref::StoreOp>(loc, v, value);
1267 EmitCInterface::On);
1268 builder.create<func::CallOp>(loc,
TypeRange(), fn, operands);
1269 builder.create<sparse_tensor::YieldOp>(loc);
1273 createFuncCall(rewriter, loc,
"delSparseTensorWriter", {}, {writer},
1274 EmitCInterface::Off);
1288 patterns.
add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1289 GenSemiRingReduction, GenSemiRingSelect>(patterns.
getContext());
1294 bool enableConvert) {
1295 patterns.
add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1296 ReshapeRewriter<tensor::CollapseShapeOp>,
1297 Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1298 Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1299 SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1311 patterns.
add<CrdTranslateRewriter, ForeachRewriter>(patterns.
getContext());
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
static bool isMulChain(Value val, Value x)
static bool isSampling(GenericOp op)
static bool isSumOfMul(GenericOp op)
static bool isZeroValue(Value val)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp, with the assumption that sizes are the dimension sizes for the type, and appends them to dynSizes.
static LogicalResult genForeachOnSparseConstant(ForeachOp op, RewriterBase &rewriter, SparseElementsAttr attr)
static bool isMaterializing(OpOperand *op, bool isZero)
static void concatSizesFromInputs(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType dstTp, ValueRange srcs, unsigned dim)
Populates the given sizes array for concatenation from types (for static sizes) and from the source tensors (for dynamic sizes).
static bool isSparseTensor(Value v)
static bool isZeroYield(GenericOp op)
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType stp, Value tensor)
Populates given sizes array from type (for static sizes) and from the tensor (for dynamic sizes).
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
TypedAttr getZeroAttr(Type type)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getBlock() const
Returns the current block of the builder.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void setOperand(unsigned idx, Value value)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)
Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the requested dimension-DynSize.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
RankedTensorType getRankedTensorType() const
Explicitly convert to RankedTensorType.
AffineMap getExpandedDimToLvl() const
Returns the dimToLvl mapping, where the identity map is expanded out into a full AffineMap.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Returns a function reference (first hit also inserts into module).
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type, and returns the memref<$tp>.
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
void foreachInSparseConstant(OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order, function_ref< void(ArrayRef< Value >, Value)> callback)
Iterate over a sparse constant, generates constantOp for value and coordinates.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic (for shapes).
RankedTensorType getRankedTensorType(T &&t)
Convenience method to abbreviate casting getType().
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genIsNonzero(OpBuilder &builder, Location loc, Value v)
Generates the comparison v != 0 where v is of numeric type.
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref<? x $tp> (rather than as type memref<$sz x $tp>).
void genReshapeDstShape(OpBuilder &builder, Location loc, SmallVectorImpl< Value > &dstShape, ArrayRef< Value > srcShape, ArrayRef< Size > staticDstShape, ArrayRef< ReassociationIndices > reassociation)
Computes the shape of destination tensor of a reshape operator.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
void reshapeCvs(OpBuilder &builder, Location loc, ArrayRef< ReassociationIndices > reassociation, ValueRange srcSizes, ValueRange srcCvs, ValueRange dstSizes, SmallVectorImpl< Value > &dstCvs)
Reshape coordinates during a reshaping operation.
bool hasAnySparseOperand(Operation *op)
Returns true iff MLIR operand has any sparse operand.
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
void sizesFromSrc(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, Value src)
Populates given sizes array from dense tensor or sparse tensor constant.
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...