#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))

      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
        return tensor::DimOp::create(builder, loc, v, dim);
      .Case<MemRefType>([&](MemRefType t) -> Value {
        return memref::DimOp::create(builder, loc, v, dim);

      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return memref::SubViewOp::create(b, loc, source, offsets, sizes,
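// Informal illustration of the two cases above (a sketch, not verbatim
// compiler output; operand names are hypothetical): tensor sources get a
// tensor.extract_slice, memref sources a memref.subview, e.g.
//   %s = tensor.extract_slice %t[0][%sz][1] : tensor<?xf32> to tensor<?xf32>
//   %v = memref.subview %m[0][%sz][1]
//        : memref<?xf32> to memref<?xf32, strided<[1]>>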
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
  llvm_unreachable("Expected MemRefType or TensorType");

  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {

  opBuilder.createBlock(&region, {}, argTypes, argLocs);

  regionBuilder(b, *body, attrs, emitError);

    std::optional<TypeRange> resultTensorTypes,

  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  state.addAttributes(attributes);
                     "operandSegmentSizes",
                     static_cast<int32_t>(outputs.size())}));
  Region &region = *state.addRegion();
                        state.attributes.getAttrs(), {},

    std::optional<TypeRange> resultTensorTypes,
    indexingMapsAttrVal = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(b.getContext()),
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
                   attributes, regionBuilder);

    std::optional<TypeRange> resultTensorTypes,
    indexingMapsAttrVal =
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
                   attributes, regionBuilder);

    std::optional<TypeRange> resultTensorTypes,
    indexingMapsAttrVal =
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
                   attributes, regionBuilder);
    bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,

  if (addOperandSegmentSizes) {
    attrs.append("operandSegmentSizes",
                 {static_cast<int32_t>(inputsOperands.size()),
                  static_cast<int32_t>(outputsOperands.size())}));
                 {static_cast<int32_t>(inputsOperands.size()),
                  static_cast<int32_t>(outputsOperands.size())}));

  std::optional<RegisteredOperationName> info =
  if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
        return parser.emitError(attrsLoc)
               << "'" << result.name.getStringRef() << "' op ";
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));

  ParseResult result = success();
      opBuilder, region, inputTypes, outputTypes, attrs,

    unsigned numRegionArgs,

  result.addTypes(outputTensorsTypes);
  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (resultTypes.empty())
class RegionBuilderHelper {
      : builder(builder), block(block) {}

    if (!isFloatingPoint(arg)) {
      emitError() << "unsupported non numeric type";
      llvm_unreachable("unsupported non numeric type");
    builder.setInsertionPointToEnd(&block);
      return math::ExpOp::create(builder, arg.getLoc(), arg);
      return math::LogOp::create(builder, arg.getLoc(), arg);
      return math::AbsFOp::create(builder, arg.getLoc(), arg);
      return math::CeilOp::create(builder, arg.getLoc(), arg);
      return math::FloorOp::create(builder, arg.getLoc(), arg);
      return arith::NegFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      auto one = arith::ConstantOp::create(builder, arg.getLoc(),
                                           ::cast<TypedAttr>(oneAttr));
      return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
      return math::RoundOp::create(builder, arg.getLoc(), arg);
      return math::SqrtOp::create(builder, arg.getLoc(), arg);
      return math::RsqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::square:
      return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
      return math::TanhOp::create(builder, arg.getLoc(), arg);
      return math::ErfOp::create(builder, arg.getLoc(), arg);
    emitError() << "unsupported unary function";
    llvm_unreachable("unsupported unary function");
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    if (!allComplex && !allFloatingPoint && !allInteger) {
          << "Cannot build binary Linalg operation: expects allComplex, "
             "allFloatingPoint, or allInteger, got "
      llvm_unreachable("unsupported non numeric type");
    builder.setInsertionPointToEnd(&block);
        return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
        return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
        return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
        emitError() << "unsupported operation: sub with bools";
        llvm_unreachable("unsupported operation: sub with bools");
      return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
        return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
        return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
        return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
        emitError() << "unsupported operation: div with bools";
        llvm_unreachable("unsupported operation: div with bools");
      return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        emitError() << "unsupported operation: unsigned div not on uint";
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
      assert(allFloatingPoint);
      return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
    emitError() << "unsupported binary function";
    llvm_unreachable("unsupported binary function");
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    builder.setInsertionPointToEnd(&block);
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
    emitError() << "unsupported ternary function";
    llvm_unreachable("unsupported ternary function");

    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    emitError() << "unsupported type conversion function";
    llvm_unreachable("unsupported type conversion function");

    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    YieldOp::create(builder, loc, values);

  Value constant(const std::string &value) {
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    return arith::ConstantOp::create(builder, loc,
                                     ::cast<TypedAttr>(valueAttr));
  Value index(int64_t dim) {
    builder.setInsertionPointToEnd(&block);
    return IndexOp::create(builder, builder.getUnknownLoc(), dim);

  Type getIntegerType(unsigned width) {

    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    if (isa<UnknownLoc>(loc)) {

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  LogicalResult matchAndRewrite(CopyOp copyOp,
    if (copyOp.getInputs() != copyOp.getOutputs())
    if (copyOp.hasPureBufferSemantics())
    rewriter.replaceOp(copyOp, copyOp.getInputs());

  results.add<EraseSelfCopy>(context);
template <typename TensorReshapeOp>
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();

    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation());
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();

    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)

          padOp, "failed to reify tensor.pad op result shape");

        tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
                                padOp.getResultType().getElementType());
        FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
    if (replacement.getType() != padOp.getResultType()) {
      replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
                                           padOp.getResultType(), replacement);
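// Informal sketch of the rewrite performed by this pattern (%cst is the
// shared constant; names and shapes are illustrative only):
//   %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<4xf32>) -> tensor<4xf32>
//   %1 = tensor.pad %0 low[2] high[2] { ... tensor.yield %cst ... }
// folds into one linalg.fill of %cst over a tensor.empty of the padded shape.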
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())

    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())

      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))

        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {

      firstDest = prevOp.getDest();

    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)

    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
            tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));

        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();

    Value extractedScalar = fillOp.getInputs()[0];
    rewriter.replaceOp(extractOp, extractedScalar);

static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();

  if (auto paddingValue = packOp.getPaddingValue())

  Value packOpDest = packOp.getDest();

  return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    rewriter.replaceOp(packOp, fillOp.value().result());
  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
          copyOp.getOutputs());
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
          fillOp.getOutputs());

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (fillVal != firstFillVal)
      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());

    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
          concatOp, "not all operands are defined by a compatible fill op");

    Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
                                                concatOp.getDim(), allOuts);
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);

  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
      blockArgLocs.push_back(v.getLoc());

  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);

void GenericOp::getAsmBlockArgumentNames(Region &region,
  for (Value v : getRegionInputArgs())
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");

void GenericOp::build(
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
                     inputs, outputs, bodyBuild);

void GenericOp::build(
    StringRef libraryCall,
  build(builder, result, resultTensorTypes, inputs, outputs,
          return IteratorTypeAttr::get(builder.getContext(), iter);
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);

void GenericOp::build(
    StringRef libraryCall,
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);

void GenericOp::build(
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        "", bodyBuild, attributes);

void GenericOp::build(
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        "", bodyBuild, attributes);
  auto genericAttrNames = linalgTraitAttrNames();

  genericAttrNamesSet.insert_range(genericAttrNames);

  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();

          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);

  if (!genericAttrs.empty()) {
    p << genericDictAttr;

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
  if (hasExtraAttrs) {

  if (!getRegion().empty()) {

  DictionaryAttr dictAttr;
                           dictAttr.getValue().end());

  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(

  std::unique_ptr<Region> region = std::make_unique<Region>();

  result.addTypes(outputTensorsTypes);
                                  LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
    effects.emplace_back(

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {

void GenericOp::getEffects(

  if (!linalgOp.hasPureTensorSemantics())
template <typename OpTy>
  LogicalResult matchAndRewrite(OpTy linalgOp,
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))

    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))

    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());

    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
            linalgOp, "expected single input and output to be the same value");

      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
                                           "cannot fold fill-like op");

    if (!linalgOp.hasPureTensorSemantics()) {
          linalgOp, "mixed semantics is not supported yet");

      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)

      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();

      if (returnType != resultType) {
          returnedArg = sparse_tensor::ConvertOp::create(
              rewriter, linalgOp.getLoc(), resultType, returnedArg);
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
          returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
                                               resultType, returnedArg);
      returnedArgs.push_back(returnedArg);

    if (returnedArgs.size() != linalgOp->getNumResults())
    rewriter.replaceOp(linalgOp, returnedArgs);

  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))

  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))

void MapOp::getAsmBlockArgumentNames(Region &region,
  for (Value v : getRegionInputArgs())

  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");

  build(builder, result, TypeRange{}, inputs, init);

  if (llvm::isa<RankedTensorType>(initType))
                     inputs, {}, bodyBuild);

                                 bool initFirst = false) {
  for (auto &operand : operands) {
        llvm::cast<ShapedType>(operand.getType()).getElementType(),

    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);

      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
  std::optional<OperationName> payloadOpName;
    if (failed(operationName))
    payloadOpName = operationName.value();

  if (payloadOpName.has_value()) {

  for (const auto &[operand, bbArg] :
    if (bbArg != operand)

  for (const auto &[operand, bbArg] :
    if (bbArg != operand)

  std::string attrToElide;
  for (const auto &attr : payloadOp->getAttrs()) {
        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
      attrToElide = attr.getName().str();
      elidedAttrs.push_back(attrToElide);

  Block *mapper = getBody();
                        [&](auto arg) { p.printRegionArgument(arg); });
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                         << getInputs().size() << " and " << blockArgs.size();

  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;

  auto outputShape = getInit().getType().getShape();
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape

  int64_t rank = getInit().getType().getRank();

ArrayAttr MapOp::getIndexingMaps() {
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();

void MapOp::getEffects(
void ReduceOp::getAsmBlockArgumentNames(Region &region,
  for (Value v : getRegionInputArgs())
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");

void ReduceOp::getAsmResultNames(
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");

void ReduceOp::build(
  build(builder, result, TypeRange{}, inputs, inits, dimensions);

  for (Value init : inits) {
    if (llvm::isa<RankedTensorType>(initType))
                       inputs, inits, bodyBuild);

      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
                     utils::IteratorType::parallel);
  for (int64_t reductionDim : getDimensions())
    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
  return iteratorTypes;

ArrayAttr ReduceOp::getIndexingMaps() {
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);

void ReduceOp::getEffects(
                                   StringRef attributeName) {

  std::optional<OperationName> payloadOpName;
    if (failed(operationName))
    payloadOpName = operationName.value();

  if (payloadOpName.has_value()) {

  p << ' ' << attributeName << " = [" << attributeValue << "] ";

  Block *mapper = getBody();
                        [&](auto arg) { p.printRegionArgument(arg); });
  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";

  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";

  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    dimensionsToReduce.insert(dimension);

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());

  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
    return emitOpError()
           << "mismatching number of operands and block arguments";

  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "

  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
        linalg::YieldOp::create(b, loc, args[0]);

    if (llvm::isa<RankedTensorType>(initType))

void TransposeOp::getAsmResultNames(
  if (!getResults().empty())
    setNameFn(getResults().front(), "transposed");
    return emitOpError("permutation is not valid");

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t rank = inputType.getRank();

  if (rank != initType.getRank())
    return emitOpError() << "input rank " << rank
                         << " does not match init rank " << initType.getRank();

  if (rank != static_cast<int64_t>(permutationRef.size()))
    return emitOpError() << "size of permutation " << permutationRef.size()
                         << " does not match the argument rank " << rank;

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  for (int64_t i = 0; i < rank; ++i) {
    int64_t inputDim = inputDims[permutationRef[i]];
    int64_t initDim = initDims[i];

    if (inputDim != initDim) {
      return emitOpError() << "dim(result, " << i << ") = " << initDim
                           << " doesn't match dim(input, permutation[" << i
                           << "]) = " << inputDim;
  int64_t rank = getInit().getType().getRank();

ArrayAttr TransposeOp::getIndexingMaps() {
  int64_t rank = getInit().getType().getRank();
          llvm::to_vector_of<unsigned>(getPermutation()), getContext())),

void TransposeOp::getEffects(

LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
  if (!isa<TensorType>(getInput().getType()))

  if (getPermutation().size() == 0) {
    result.push_back(getInput());

    result.push_back(getInput());

    auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
    if (!defTransposeOp)
    foldedPerms.reserve(perms.size());
    for (int64_t perm : perms)
      foldedPerms.push_back(defPerms[perm]);

        transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
    Value input = transposeOp.getInput();
    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();

    unsigned dimensionSize = dimensions.size();
    for (unsigned i = 0; i < dimensionSize; ++i)
      resultDimensions.push_back(invertPerm[dimensions[i]]);

    Value broadcastInput = broadcastOp.getInput();
    Location loc = transposeOp.getLoc();

    auto broadcastInputTy =
        mlir::cast<RankedTensorType>(broadcastInput.getType());
    unsigned inputRank = broadcastInputTy.getRank();
    for (unsigned i = 0; i < inputRank; ++i) {
      if (broadcastInputTy.isDynamicDim(i)) {
        dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
                                       broadcastInputTy.getDimSize(i)));

    Value transposeInit = tensor::EmptyOp::create(
        rewriter, transposeOp.getLoc(), transposeResultShapes,
        broadcastInputTy.getElementType());

    Value transposeResult =
        TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
                            transposeInit, resultPerms)
        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
    if (llvm::isa<RankedTensorType>(initType))

void BroadcastOp::getAsmResultNames(
  if (!getResults().empty())
    setNameFn(getResults().front(), "broadcasted");

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();

  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();

  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         << inputRank
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;

    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;

  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
    if (!llvm::is_contained(dimensionsRef, dim))
      dimMap.push_back(dim);

  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  int64_t rank = getInit().getType().getRank();

ArrayAttr BroadcastOp::getIndexingMaps() {
  int64_t rank = getInit().getType().getRank();

void BroadcastOp::getEffects(

  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);

  if (getNumOperands() > 0)
    p << ' ' << getOperands();
  if (getNumOperands() > 0)
    p << " : " << getOperandTypes();
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    if (isa<MemRefType, RankedTensorType>(elementType))
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";

  auto *parentOp = (*this)->getParentOp();
  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
    return emitOpError("expected single non-empty parent region");

  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))

  return emitOpError("expected parent op with LinalgOp interface");

  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
    return emitOpError("expected parent op with LinalgOp interface");
  if (linalgOp.getNumLoops() <= getDim())
    return emitOpError("expected dim (")
           << getDim() << ") to be lower than the number of loops ("
           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";

  auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());

  uint64_t dim = getDim();
  assert(dim < loopBounds.size() && "Dim is out of bounds");
  if (loopBounds[dim] == 1)
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
  for (unsigned i = 0; i < num; ++i)

  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);

  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    for (auto size : memref.getShape())
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();

  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });

  assert(isa<LinalgOp>(op));
  std::string fun = "";
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(bfa.getValue()).str() + "_";

  llvm::replace(name, '.', '_');
  llvm::raw_string_ostream ss(name);

  return std::string();
  LogicalResult matchAndRewrite(LinalgOp op,
    for (OpOperand &opOperand : op->getOpOperands()) {
      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {

struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  LogicalResult matchAndRewrite(tensor::CastOp castOp,
    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();

    if (castOp->getBlock() != linalgOp->getBlock())

    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());

    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
        tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
                                          linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());

                                 linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

    Value castBack = tensor::CastOp::create(
    results[resultNumber] = castBack;
    if (linalgOp.isScalar(&opOperand))
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();

    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
static void createNewOperandWithStaticSizes(
    bool &changeNeeded) {
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);

  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);

  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);

    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;

                                     sourceType.getEncoding());
  if (newOperandNeeded) {
    changeNeeded = true;
    Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
    newOperands[index] = newOperand;

  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
  LogicalResult matchAndRewrite(LinalgOp linalgOp,
    if (!linalgOp.hasPureTensorSemantics())

    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
          return !map.isProjectedPermutation();

    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);

    bool changeNeeded = false;
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumDpsInits());

    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);

    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
      Value newResult = std::get<1>(it);
      Value oldResult = std::get<0>(it);
      replacements.push_back(
          (newType != oldType)
              ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
    rewriter.replaceOp(linalgOp, replacements);
  ShapedType inputType = getInputOperandType();
  ShapedType outputType = getOutputOperandType();

    return emitOpError("incompatible output shape");

  int64_t inputRank = getInputOperandRank();
  int64_t dimension = getDimension();
  if ((dimension < 0) || (dimension >= inputRank))
    return emitOpError("incorrect dimension specified");

  int64_t operandRank = getInputOperandRank();

  Value source = getInput();
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;

                     utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;

FailureOr<TilingResult>
  int64_t rank = getInputOperandRank();

      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
    return emitOpError("failed to compute input slice");

  tiledOperands.emplace_back(inputSlice->getResult(0));
      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
    return emitOpError("failed to compute output slice");

  tiledOperands.emplace_back(outputSlice->getResult(0));

  if (hasPureTensorSemantics())
    resultTypes.push_back(tiledOperands[1].getType());
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  if (resultNumber == 0) {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());

  Location loc = getOperation()->getLoc();
  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
    if (!outputShapedType.isDynamicDim(dim)) {
      shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));

  reifiedReturnShapes.emplace_back(std::move(shapes));

void SoftmaxOp::getEffects(
    if (!llvm::isa<MemRefType>(operand.getType()))
          &getOperation()->getOpOperand(index), 0,

  for (OpOperand &operand : getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))

                                  int64_t dim, bool allParallel = false) {
                     utils::IteratorType::parallel);
    iteratorTypes[dim] = utils::IteratorType::reduction;

  for (int i = 0; i < inputRank; i++) {

  return std::make_tuple(iteratorTypes, indexingMaps);
template <typename T>
  auto inputType = cast<ShapedType>(input.getType());
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] =
  assert(indexingMaps.size() == 2 &&
         "We should have two maps: 1 for the input, 1 for the output");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");

  auto genericOp = linalg::GenericOp::create(
      builder, loc, output.getType(), input, output, indexingMaps,
        Value result = T::create(b, loc, args[0], args[1]);
        linalg::YieldOp::create(b, loc, result);
  return genericOp.getResult(0);
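// Typical instantiations of this helper, as used by the softmax
// decomposition further below:
//   reduce<arith::MaxNumFOp>(b, loc, input, maxInit, dim);  // running max
//   reduce<arith::AddFOp>(b, loc, exps, zeroInit, dim);     // running sum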
  auto inputType = cast<ShapedType>(input.getType());
  int64_t inputRank = inputShape.size();
      builder, inputRank, dim, true);
  assert(indexingMaps.size() == 2 && "We should have one map for each input");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");

  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = linalg::GenericOp::create(
      indexingMaps, iteratorTypes,
        Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
        Value result = math::ExpOp::create(b, loc, diff);
        linalg::YieldOp::create(b, loc, result);
  return genericOp.getResult(0);

                         Value denominator, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(numerator.getType());
  int64_t inputRank = inputShape.size();
      builder, inputRank, dim, true);
  assert(indexingMaps.size() == 2 &&
         "We should have one map for each input (2)");
  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");

  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = linalg::GenericOp::create(
      output, indexingMaps, iteratorTypes,
        Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
        linalg::YieldOp::create(b, loc, result);
  return genericOp.getResult(0);
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  Value output = getOutput();
  dims.erase(dims.begin() + reductionDim);
  Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
      elementType, b, loc,
  Value neutralForMaxFInit =
      linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);

      linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
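// Taken together, the steps above compute the numerically stable softmax:
//   softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
// i.e. a max-reduction, a fused subtract+exp, a sum-reduction, and a final
// elementwise division along the reduction dimension.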
  auto filterType = cast<ShapedType>(getFilter().getType());
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  WinogradConv2DFmr fmr = getFmr();

  if (filterH != r && filterH != 1)
    return emitOpError("expect filter height either equals to r or 1");
  if (filterW != r && filterW != 1)
    return emitOpError("expect filter width either equals to r or 1");
  if (filterH == 1 && filterW == 1)
    return emitOpError("expect either filter height or width equals to r");

  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
  expectedOutputShape.push_back(filterShape[getFilterFDim()]);

  auto outputType = cast<ShapedType>(getOutput().getType());
    return emitOpError("the output shape is not expected");
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
  Value filter = getFilter();
  int64_t filterRank = getFilterOperandRank();
  for (unsigned dim = 0; dim < filterRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
    loopBounds[dim].stride = oneAttr;

WinogradFilterTransformOp::getLoopIteratorTypes() {
  int64_t filterRank = getFilterOperandRank();
                     utils::IteratorType::parallel);
  return iteratorTypes;

  ShapedType filterType = getFilterOperandType();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t alpha = m + r - 1;
  int64_t alphaH = filterH != 1 ? alpha : 1;
  int64_t alphaW = filterW != 1 ? alpha : 1;

  resultOffsets.append(
      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});

  ShapedType filterType = getFilterOperandType();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];

  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  auto filterSlice = tensor::ExtractSliceOp::create(
      builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);

  int64_t outputRank = getOutputOperandRank();
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  resultTypes.push_back(tiledOperands[1].getType());
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  auto inputType = cast<ShapedType>(getInput().getType());
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t tileSize = m + r - 1;

  auto outputType = cast<ShapedType>(getOutput().getType());
  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;

  if (ShapedType::isDynamic(inputH)) {
    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileHDim()] =
        leftTransform ? (inputH - (r - 1)) / m : inputH;
  if (ShapedType::isDynamic(inputW)) {
    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileWDim()] =
        rightTransform ? (inputW - (r - 1)) / m : inputW;
  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

    return emitOpError("the output shape is not expected");
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  Value output = getOutput();
  int64_t outputRank = getOutputOperandRank();
  for (unsigned dim = 0; dim < outputRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
    loopBounds[dim].stride = oneAttr;

WinogradInputTransformOp::getLoopIteratorTypes() {
  int64_t outputRank = getOutputOperandRank();
                     utils::IteratorType::parallel);
  return iteratorTypes;

  ShapedType outputType = getOutputOperandType();
  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];

  WinogradConv2DFmr fmr = getFmr();
  int64_t alpha = m + r - 1;
  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;

  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
                        offsets[getOutputCDim()]});
  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
                      sizes[getOutputCDim()]});

FailureOr<TilingResult>
  WinogradConv2DFmr fmr = getFmr();

  ShapedType outputType = getOutputOperandType();
  int64_t alphaH = outputShape[getOutputAlphaHDim()];
  int64_t alphaW = outputShape[getOutputAlphaWDim()];

  auto identityAffineMap =
  auto offsetAffineMap =
      builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileHDim()]);
      builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileWDim()]);
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);

  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  auto inputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);

  int64_t outputRank = getOutputOperandRank();
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  resultTypes.push_back(tiledOperands[1].getType());
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  auto valueType = cast<ShapedType>(getValue().getType());
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  WinogradConv2DFmr fmr = getFmr();
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expect input height equals to input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expect input width equals to input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
    return emitOpError("the output shape is not expected");
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  Value value = getValue();
  int64_t valueRank = getValueOperandRank();
  for (unsigned dim = 0; dim < valueRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
    loopBounds[dim].stride = oneAttr;

WinogradOutputTransformOp::getLoopIteratorTypes() {
  int64_t valueRank = getValueOperandRank();
                     utils::IteratorType::parallel);
  return iteratorTypes;

  WinogradConv2DFmr fmr = getFmr();
  auto identityAffineMap =

  ShapedType valueType = getValueOperandType();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
      builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
      offsets[getValueTileHDim()]);
      builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
      offsets[getValueTileWDim()]);
      builder, loc, affineMap, sizes[getValueTileHDim()]);
      builder, loc, affineMap, sizes[getValueTileWDim()]);

  resultOffsets.append(
      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});

  ShapedType valueType = getValueOperandType();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];

  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  auto valueSlice = tensor::ExtractSliceOp::create(
      builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  int64_t outputRank = getOutputOperandRank();
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  resultTypes.push_back(tiledOperands[1].getType());
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  llvm::set_union(explicitSet, defaultSet);
  return explicitSet == defaultSet;

      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
      return matmulOp->emitOpError()
             << "Unexpected dim expression in map result.";

    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
      return matmulOp->emitOpError()
             << "Invalid broadcast requested, should be (d2).";
template <typename OpTy>
                AffineMap defaultIndexingMap, bool isLHS) {
  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
         "Expected BatchMatmulOp or BatchReduceMatmulOp");
    return batchVariantMatmulOp->emitOpError()
           << "Unexpected result dim expression (outside the set of default "
    return batchVariantMatmulOp->emitOpError()
           << "no. of result dim expressions exceeds 3.";

  auto hasValidBatchDim = [](AffineMap map) {

    if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
      return batchVariantMatmulOp->emitOpError()
             << "Invalid broadcast requested.";
  } else if (!hasValidBatchDim(opIndexingMap)) {
    return batchVariantMatmulOp->emitOpError()
           << "Invalid batch dimension expression.";
3624 template <
typename OpTy>
3627 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3628 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3629 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3630 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3633 return batchVariantMatmulOp->emitOpError()
3634 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3637 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3639 return batchVariantMatmulOp->emitOpError()
3640 <<
"expects 2 dims, but got (" << opIndexingMap.
getNumResults()
3644 auto areValidOutputResultDim = [&](
AffineMap outputMap) {
3645 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3646 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3647 outputMap.getResult(1).isFunctionOfDim(1) &&
3648 outputMap.getResult(2).isFunctionOfDim(2)
3649 : outputMap.getResult(0).isFunctionOfDim(1) &&
3650 outputMap.getResult(1).isFunctionOfDim(2);
3653 if (!areValidOutputResultDim(opIndexingMap)) {
3654 return batchVariantMatmulOp->emitOpError()
3655 <<
"Invalid output map result dimension.";
3664 template <
typename OpTy>
3665 static LogicalResult
3669 batchVariantMatmulOp.getIndexingMapsArray();
3671 batchVariantMatmulOp.getDefaultIndexingMaps(
3672 batchVariantMatmulOp->getContext());
3674 if (opIndexingMaps.size() != 3)
3675 return batchVariantMatmulOp->emitOpError()
3676 <<
"Indexing_map attribute must have 3 affine maps.";
3678 auto opIndexingMap = opIndexingMaps[opIndex];
3679 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3687 defaultIndexingMap, opIndex == 0)))
std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
  if (m == 2 && r == 3)
    return WinogradConv2DFmr::F_2_3;
  if (m == 4 && r == 3)
    return WinogradConv2DFmr::F_4_3;
  if (m == 2 && r == 5)
    return WinogradConv2DFmr::F_2_5;
  return std::nullopt;
}

std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
  switch (fmr) {
  case WinogradConv2DFmr::F_2_3:
    return {2, 3};
  case WinogradConv2DFmr::F_4_3:
    return {4, 3};
  case WinogradConv2DFmr::F_2_5:
    return {2, 5};
  }
  llvm_unreachable("unknown WinogradConv2DFmr");
}
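// Naming follows the Winograd convention F(m, r): m is the output tile size
// and r the filter size. Note the asymmetry of the two conversions:
// getWinogradConv2DFmr is partial (std::nullopt for unsupported (m, r)
// pairs), while getFmrFromWinogradConv2DFmr is total over the enum.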
  // ... (construction of the three default affine maps elided in this excerpt)
  return indexingMaps;
}

SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                          utils::IteratorType::parallel,
                                          utils::IteratorType::reduction};
}

unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

/// Returns true if the user-defined indexing maps differ from the default
/// maps, i.e. the op carries broadcast and/or transpose semantics.
bool MatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Implements the block region builder. Called by 'fillStructuredOpRegion'.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs,
                             function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "MatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
  if (!value3)
    return;
  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      value3, emitError);
  if (!value4)
    return;
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}
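// The region built above is equivalent to the following IR for f32 operands
// (illustrative sketch; the exact ops depend on the element types and the
// optional `cast` attribute):
//   ^bb0(%a: f32, %b: f32, %acc: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %acc, %0 : f32
//     linalg.yield %1 : f32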
/// Returns true if the given broadcast map `bcastMap` is valid for this op.
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  AffineExpr expr = bcastMap.getResult(0);
  // Invalid map if the common (reduction) dimension of matmul is not found.
  return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
}

FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
  if (parser.parseOptionalKeyword("indexing_maps"))
    return ArrayAttr{nullptr}; // Success: not present.
  ArrayAttr arrayAttr;
  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
    return failure();
  if (llvm::any_of(arrayAttr,
                   [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
    return parser.emitError(parser.getCurrentLocation())
           << "element of indexing_maps array is not an affine_map";
  return arrayAttr;
}

ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr))
    return failure();
  if (*indexingMapsAttr == nullptr) {
    auto indexingMapAttrs = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
  }
  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}

void MatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}

/// Verify the user defined indexing maps.
LogicalResult MatmulOp::verify() {
  // Verification of pure matmul is handled by verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
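// Example of the extended semantics verified above (illustrative sketch;
// shapes chosen with m = 3, n = 7, k = 5 and the LHS stored transposed):
//   %0 = linalg.matmul indexing_maps = [
//          affine_map<(d0, d1, d2) -> (d2, d0)>,  // transposed LHS
//          affine_map<(d0, d1, d2) -> (d2, d1)>,
//          affine_map<(d0, d1, d2) -> (d0, d1)>]
//        ins(%a, %b : tensor<5x3xf32>, tensor<5x7xf32>)
//        outs(%c : tensor<3x7xf32>) -> tensor<3x7xf32>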
SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
  // domains are all the same, and each implements a projected permutation.
  // A dim is a contraction/reduction dim if and only if it does not occur in
  // the output's access map; use this fact for fast inference.
  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
  for (auto result : outAffineMap.getResults()) {
    auto dimExpr = dyn_cast<AffineDimExpr>(result);
    assert(dimExpr && "affine_map is a projected permutation");
    dimsInOutput[dimExpr.getPosition()] = true;
  }

  SmallVector<utils::IteratorType> iteratorTypes;
  for (auto dimOccursInOutput : dimsInOutput)
    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
                                              : utils::IteratorType::reduction);

  return iteratorTypes;
}

unsigned ContractOp::getNumRegionArgs() { return 3; }
void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                               ArrayRef<NamedAttribute> attrs,
                               function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "ContractOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "ContractOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);

  TypeFn castSignedness = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castSignedness = attr.getValue();
  }

  Type outType = block.getArgument(2).getType();
  Value lhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
  Value rhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
  Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
                                                rhsAtOutType, emitError);
  if (!productAtOutType)
    return;
  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      productAtOutType, emitError);
  if (!result)
    return;
  helper.yieldOutputs({result});
}

ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
    return parser.emitError(parser.getCurrentLocation(),
                            "expected 'indexing_maps' attribute");
  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
                                getRegionBuilder());
}

void ContractOp::print(OpAsmPrinter &p) {
  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         {"indexing_maps", "operandSegmentSizes"});
}
LogicalResult ContractOp::verify() {
  int iterationSpaceDims = -1;
  // Map iteration space dims to the number of occurrences in the inputs' and
  // the output's access maps: inOccurrences[i] can be at most 2 (one per
  // input), outOccurrences[i] indicates whether dim i occurs in the output.
  SmallVector<size_t> inOccurrences;
  SmallVector<size_t> outOccurrences;

  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
    // The affine_map must be a projected permutation ...
    if (!affineMap.isProjectedPermutation())
      return emitError("provided affine_map is not a projected permutation");

    // ... whose number of results matches the rank of the shaped operand.
    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
      if (affineMap.getNumResults() != shapedType.getRank())
        return emitError("ranks of shaped operand and results of corresponding "
                         "affine_map differ");
    } else if (affineMap.getNumResults() != 0) {
      return emitError("affine_map specifies shaped access while operand has "
                       "non-shaped type");
    }

    // All maps must agree on the iteration space.
    if (iterationSpaceDims == -1) {
      iterationSpaceDims = affineMap.getNumDims();
      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError("iteration spaces of provided affine_maps differ");
    }

    // Update the occurrence counts of each accessed dim.
    for (AffineExpr affineExpr : affineMap.getResults()) {
      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");

      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;
    }
    return success();
  };

  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
                 SmallVector<bool>{true, true, false})) {
    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure(); // NB: checkAffineMapAndType will emit relevant error.
  }

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];

    // A dim is contracting iff it occurs in both inputs and not the output.
    hasContractingDim |= inOccCount == 2 && outOccCount == 0;

    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";

    // ... (additional occurrence-count diagnostics elided in this excerpt)
    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";
  }

  if (!hasContractingDim)
    return emitError("'indexing_maps' do not specify a contracting dimension");

  return success();
}

void ContractOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
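// Worked example for the verifier above (illustrative): for plain matmul maps
//   lhs: (d0, d1, d2) -> (d0, d2), rhs: (d0, d1, d2) -> (d2, d1),
//   out: (d0, d1, d2) -> (d0, d1)
// d2 occurs in both inputs (inOccCount == 2) and not in the output
// (outOccCount == 0), so it is the contracting dim; d0 and d1 each occur in
// one input and in the output, so they iterate in parallel.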
SmallVector<AffineMap>
BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2, d3;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2, d3);
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
  return indexingMaps;
}

SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{
      utils::IteratorType::parallel, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};
}

unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
std::string BatchMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

/// Returns true if the user-defined indexing maps differ from the default
/// maps, i.e. the op carries broadcast and/or transpose semantics.
bool BatchMatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Returns true if the given broadcast map `bcastMap` is valid for this op.
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };
  if (bcastMap.getNumResults() == 1) {
    // A single-result broadcast must supply the contraction dim, e.g. (d3).
    AffineExpr expr = bcastMap.getResult(0);
    isValid = expr.isFunctionOfDim(kPos);
  } else if (bcastMap.getNumResults() == 2) {
    // ... (two-result validity checks elided in this excerpt)
  }
  return isValid;
}
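// With the default maps above, the enum positions correspond to batch = d0,
// m = d1, n = d2, k = d3. A valid single-dim broadcast of either input is
// therefore (d3), supplying only the shared contraction dimension.
// (Explanatory note for this excerpt.)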
void BatchMatmulOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "BatchMatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  auto toType = block.getArgument(2).getType();
  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  Value addVal =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
}

ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual() || parser.parseLSquare())
      return failure();
    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);
      if (parser.parseOptionalComma())
        break;
    } while (true);
    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  return parseNamedStructuredOp(parser, result,
                                BatchMatmulOp::getNumRegionArgs(),
                                BatchMatmulOp::getRegionBuilder());
}

void BatchMatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}

/// Verify the user defined indexing maps.
LogicalResult BatchMatmulOp::verify() {
  // Verification of pure batch_matmul is handled by
  // verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult BatchMatmulOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void BatchMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
namespace {
struct ArityGroupAndKind {
  // The enum class {Unary, Binary, Ternary, ...}.
  ElementwiseArityGroup arityGroup;
  // The kind (e.g. `exp` or `add`) belonging to the arity group.
  union Kind {
    UnaryFn unaryFn;
    BinaryFn binaryFn;
    TernaryFn ternaryFn;
  } kind;
};
} // namespace

unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
  return static_cast<unsigned>(arityGroup);
}

static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
  constexpr int lastBinary =
      static_cast<int>(ElementwiseCaseLimits::LastBinary);
  constexpr int lastTernary =
      static_cast<int>(ElementwiseCaseLimits::LastTernary);

  int val = static_cast<int>(kind);
  ArityGroupAndKind result;

  if (val < lastUnary) {
    result.arityGroup = ElementwiseArityGroup::Unary;
    result.kind.unaryFn = static_cast<UnaryFn>(val);
    return result;
  }
  if (val < lastBinary) {
    result.arityGroup = ElementwiseArityGroup::Binary;
    result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
    return result;
  }
  if (val >= lastTernary) {
    llvm_unreachable("unhandled ElementwiseFn");
  }
  result.arityGroup = ElementwiseArityGroup::Ternary;
  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
  return result;
}
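// The kind enum is laid out as [unary fns | binary fns | ternary fns], so a
// raw value below LastUnary decodes to a UnaryFn, and binary/ternary values
// are rebased by subtracting the previous group's limit. E.g. if LastUnary
// were 16, a raw kind value of 17 would decode to the BinaryFn with value 1.
// (Illustrative numbers; the real limits come from ElementwiseCaseLimits.)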
SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
  auto rank = getResultRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

SmallVector<AffineMap>
ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
                                      MLIRContext *context) {
  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
  return SmallVector<AffineMap>(numMaps, map);
}

ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the mandatory `kind` attribute.
  mlir::linalg::ElementwiseKind elemwiseKindVal;
  if (parser.parseKeyword("kind") || parser.parseEqual())
    return failure();
  Attribute attr;
  if (succeeded(parser.parseAttribute(attr))) {
    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
    if (!elemwiseKindAttr)
      return parser.emitError(parser.getCurrentLocation(),
                              "expected ElementwiseKind attribute");
    elemwiseKindVal = elemwiseKindAttr.getValue();
  } else {
    return parser.emitError(parser.getCurrentLocation(),
                            "expected operation 'kind' attribute");
  }
  result.addAttribute(
      "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));

  // Parse the optional `indexing_maps` array.
  SmallVector<Attribute, 3> indexingMapsAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    Attribute mapAttr;
    // ... (bracketed attribute-list parsing elided in this excerpt)
    if (!isa<AffineMapAttr>(mapAttr))
      return parser.emitError(parser.getCurrentLocation(),
                              "expected affine map attribute");
    indexingMapsAttr.push_back(mapAttr);
  }

  // At this stage the only way to infer the number of region args is through
  // the op kind, as the input/output operands are not parsed yet.
  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
  int numRegionArgs =
      getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
  if (parseNamedStructuredOp(parser, result, numRegionArgs,
                             ElementwiseOp::getRegionBuilder())) {
    return parser.emitError(parser.getCurrentLocation(),
                            "unable to parse elemwise op");
  }

  // Initialize indexingMaps, if not supplied explicitly: infer numDims from
  // the (already parsed) output type.
  if (indexingMapsAttr.empty()) {
    auto resultType = result.operands.back().getType();
    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
    if (!shapedType)
      return parser.emitError(parser.getCurrentLocation(),
                              "return type needs to be shaped type");
    auto numDims = shapedType.getRank();
    indexingMapsAttr = llvm::map_to_vector(
        ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
                                              parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  return success();
}

void ElementwiseOp::print(OpAsmPrinter &p) {
  // ... (printing of the `kind` attribute elided in this excerpt)
  unsigned arity = getArityGroupAsUInt(getArityGroup());
  unsigned numDims = getResultRank();
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
                                            getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
  printNamedStructuredOp(
      p, getOperation(), getInputs(), getOutputs(),
      /*elidedAttrs=*/{"operandSegmentSizes", "kind", "indexing_maps"});
}
void ElementwiseOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  ElementwiseKind elemwiseKind;
  for (auto attr : attrs) {
    if (attr.getName() == b.getStringAttr("kind")) {
      auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
      assert(kindAttr && "op kind attribute incorrectly set");
      elemwiseKind = kindAttr.getValue();
      break;
    }
  }

  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
  auto arityGroup = groupAndKind.arityGroup;
  auto kind = groupAndKind.kind;
  if (emitError && block.getNumArguments() !=
                       getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
    emitError() << "Elementwise regionBuilder expects "
                << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() ==
             getArityGroupAsUInt(arityGroup) + 1 /*output*/
         && "Elementwise regionBuilder number of block args mismatch");

  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;
  Value result;

  if (arityGroup == ElementwiseArityGroup::Unary) {
    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
  } else if (arityGroup == ElementwiseArityGroup::Binary) {
    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
                                  block.getArgument(1));
  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
    result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
                                   block.getArgument(1), block.getArgument(2));
  } else {
    assert(false && "found unhandled category in elemwise");
  }

  yields.push_back(result);
  helper.yieldOutputs(yields);
}

LogicalResult ElementwiseOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void ElementwiseOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
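// For a binary kind such as <add> on f32 tensors, the region built above is
// equivalent to (illustrative sketch):
//   ^bb0(%in0: f32, %in1: f32, %out: f32):
//     %0 = arith.addf %in0, %in1 : f32
//     linalg.yield %0 : f32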
template <typename OpTy, typename>
SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
                                    ? packOrUnPack.getDestType()
                                    : packOrUnPack.getSourceType();
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  SmallVector<int64_t> result(
      packedType.getShape().take_front(unpackedType.getRank()));
  if (!packOrUnPack.getOuterDimsPerm().empty()) {
    applyPermutationToVector(
        result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
  }
  return result;
}

/// Recompute the mixed tile sizes against the dims of `newPackedTy`: keep the
/// existing entry for dynamic dims, switch to the static value otherwise.
static SmallVector<OpFoldResult>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                     SmallVector<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }

    // If the current result dim is static, update the dynamic mixed tile
    // size to use the static value instead.
    OpFoldResult tile = std::get<1>(it);
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
      // Already a constant.
      newMixedTileSizes.push_back(tile);
    } else {
      assert(getConstantIntValue(tile).value() == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(rewriter.getIndexAttr(shape));
    }
  }
  return newMixedTileSizes;
}
template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
  reifiedReturnShapes[0] =
      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
  return success();
}

template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // Bind the dimension `i` with the tile factor.
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}

template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (ShapedType::isStatic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}

template <typename OpTy>
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  SmallVector<Value> dynamicTiles;
  SmallVector<int64_t> staticTiles;
  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
  return staticTiles;
}

/// Returns true if `dimsPos` is invalid. It is invalid when:
/// a) it contains a duplicate,
/// b) at least one dimension is out of bound (`dimPos` is >= 0 and < rank), or
/// c) the number of elements in `dimsPos` is > than `rank`.
static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
                                             size_t rank) {
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  llvm::SmallDenseSet<int64_t> uniqued(dimsPos.begin(), dimsPos.end());
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
}

/// Returns true if every dimension of `sourceShape` is no larger than the
/// corresponding dimension of `limitShape` (dynamic dims always pass).
static bool areAllInBound(ArrayRef<int64_t> sourceShape,
                          ArrayRef<int64_t> limitShape) {
  assert(
      sourceShape.size() == limitShape.size() &&
      "expected source shape rank, and limit of the shape to have same rank");
  return llvm::all_of(
      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
        int64_t sourceExtent = std::get<0>(it);
        int64_t limit = std::get<1>(it);
        return ShapedType::isDynamic(sourceExtent) ||
               ShapedType::isDynamic(limit) || sourceExtent <= limit;
      });
}
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  // Return true if we have a zero-value tile.
  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
    return llvm::any_of(
        tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
  };

  // Verify tiles. Do not allow zero tiles.
  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  // Tiling factors must be <= the input rank for pack (output rank for
  // unpack), and must match the number of `inner_dims_pos`.
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  // Require packed rank == unpacked rank + number of tiling factors.
  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
  if (expectedPackedRank != packedRank) {
    return op->emitError(
               "packed rank != (unpacked rank + num tiling factors), got ")
           << packedRank << " != " << expectedPackedRank;
  }

  // Verify that the packed shape is large enough to hold the expected packed
  // type, and that the trailing tile dims match the specified tile sizes.
  RankedTensorType expectedPackedType = PackOp::inferPackedType(
      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
    return op->emitError("the shape of output is not large enough to hold the "
                         "packed data. Expected at least ")
           << expectedPackedType << ", got " << packedType;
  }
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            int64_t shape = std::get<0>(it);
            if (Attribute attr =
                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
              int64_t staticTileSize = intAttr.getValue().getSExtValue();
              return shape == staticTileSize;
            }
            return ShapedType::isDynamic(shape);
          })) {
    return op->emitError("mismatch in inner tile sizes specified and shaped of "
                         "tiled dimension in the packed type");
  }
  return success();
}
namespace {
struct PackOrUnPackTransposeResult {
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};
} // namespace

template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}

void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}

void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}
ArrayRef<int64_t> PackOp::getAllOuterDims() {
  ShapedType inputType = getSourceType();
  int64_t inputRank = inputType.getRank();
  return getDestType().getShape().take_front(inputRank);
}

SmallVector<int64_t> PackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getDestType().getShape();
  SmallVector<int64_t> res;
  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);
  return res;
}

bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outputShape,
                                 ArrayRef<int64_t> outerDimsPerm,
                                 ArrayRef<OpFoldResult> innerTiles) {
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
    if (!constantTile) {
      if (ShapedType::isStatic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }
  }
  return false;
}

LogicalResult PackOp::verify() {
  if (failed(commonVerifierPackAndUnPackOp(*this)))
    return failure();

  // Verify the padding value, and bail out if a tile does not divide its
  // dimension fully while no padding value is provided.
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
  return success();
}
/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
/// all Value's to kDynamic, even if they are arith.constant values.
static SmallVector<int64_t>
asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
  SmallVector<int64_t> result;
  for (auto o : ofrs) {
    // Have to do this first, as getConstantIntValue special-cases constants.
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);
    else
      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
  }
  return result;
}

/// Helper for PackOp::{getResultShape,inferPackedType}: outer dims are
/// ceil-divided by the tile sizes, which are then appended as trailing dims.
static SmallVector<int64_t> getPackOpResultTypeShape(
    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
      continue;
    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
      resultShape[tiledDim.value()] = ShapedType::kDynamic;
      continue;
    }
    resultShape[tiledDim.value()] = llvm::divideCeilSigned(
        resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultShape, outerDimsPerm);
  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
  return resultShape;
}

SmallVector<OpFoldResult> PackOp::getResultShape(
    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
    ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
    ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);
  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
        builder, loc, ceilDivExpr,
        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultDims, outerDimsPerm);
  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

  // Fix up `resultDims` so that an entry is a Value iff the corresponding
  // result type dim is dynamic.
  SmallVector<int64_t> resultTypeShape = getPackOpResultTypeShape(
      asShapeWithAnyValueAsDynamic(sourceDims),
      asShapeWithAnyValueAsDynamic(innerTileSizes), innerDimsPos,
      outerDimsPerm);
  for (unsigned i = 0; i < resultDims.size(); ++i) {
    if (ShapedType::isStatic(resultTypeShape[i]))
      continue;
    resultDims[i] =
        getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
  }
  return resultDims;
}

/// Get the expected packed type based on source type, tile factors, position
/// of the inner tiles and permutation of the outer tiled loop.
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                         ArrayRef<int64_t> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
                                         ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}

Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                      ArrayRef<OpFoldResult> innerTileSizes,
                                      ArrayRef<int64_t> innerDimsPos,
                                      ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
                                                 {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  for (auto [index, value] : llvm::enumerate(
           llvm::cast<RankedTensorType>(source.getType()).getShape())) {
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(
          tensor::DimOp::create(b, loc, source, index).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(value));
  }
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    OpFoldResult tileSize = std::get<1>(it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
}

PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return PackOp::create(b, loc, getSource(), transposedDest,
                        metadata.innerDimsPos, metadata.innerTiles,
                        getPaddingValue(), metadata.outerDimsPerm);
}
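// Ceil-division example for the shape inference above (illustrative): a
// source dim of 13 tiled by 4 yields an outer extent of
// divideCeilSigned(13, 4) = 4, with the last tile partially filled; the tile
// size 4 is then appended as the trailing dim, so tensor<13xf32> packs to
// tensor<4x4xf32> (and requires a padding value for the 3 trailing slots).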
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  return true;
}

Speculation::Speculatability PackOp::getSpeculatability() {
  if (getPaddingValue())
    return Speculation::Speculatable;
  // Without a padding value, non-dividing tiles would read out of bounds.
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;
  return Speculation::Speculatable;
}

// Return true if the `inner_dims_pos` and `outer_dims_perm` attributes of the
// pack and unpack ops target the same dimensions.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
    return false;
  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
    return true;
  // Outer dims permutation is optional; treat an absent permutation as the
  // identity when comparing an unbalanced pack-unpack pair.
  SmallVector<int64_t> identityPerm =
      llvm::to_vector(llvm::seq<int64_t>(0, packOp.getSourceRank()));
  ArrayRef<int64_t> packPerm = packOp.getOuterDimsPerm();
  ArrayRef<int64_t> unPackPerm = unPackOp.getOuterDimsPerm();
  if (packPerm.empty())
    packPerm = identityPerm;
  if (unPackPerm.empty())
    unPackPerm = identityPerm;
  return packPerm == unPackPerm;
}

static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
      return false;
  }
  return true;
}

/// Returns true if the pack op does not need a padding value.
static bool paddingIsNotNeeded(PackOp op) {
  auto srcType = op.getSourceType();
  if (llvm::any_of(op.getInnerDimsPos(),
                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
    return false;
  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
    return false;
  return !PackOp::requirePaddingValue(
      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
      op.getOuterDimsPerm(), op.getMixedTiles());
}
/// Returns true if the srcShape or destShape is different from the one in
/// `packOp` and populates each with the inferred static shape.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(packOp.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // Fold pack(unpack(x)) to x.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() != packOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(packOp, unPackOp.getSource());
    return success();
  }

  // Fold the optional padding_value operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
    rewriter.startOpModification(packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    RankedTensorType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
    }
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Insert a cast back to the original result type if needed.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp =
          tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
    }
    return success();
  }

  return failure();
}
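// Round-trip folding example (illustrative): given
//   %u = linalg.unpack %p ... : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
//   %r = linalg.pack %u ...   : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
// with matching inner/outer attributes, equal tiles, and no padding value,
// the canonicalization above replaces %r with %p directly.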
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                    std::is_same<PackOrUnpackOp, UnPackOp>::value,
                "Function meant for pack/unpack");
  // This is a pad if packing only adds ones and we don't transpose dimensions.

  // Check that we are not transposing any dimensions.
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  int64_t numPackedDims = innerDimsPos.size();
  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
  if (orderedDims != innerDimsPos) {
    // Dimensions don't happen in order.
    return false;
  }

  // Pushing numPackedDims dimensions all the way inward is a no-op (as
  // opposed to a transpose) iff all the dimensions that bubble outward are
  // ones.
  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
  int64_t packedRank = packedTensorType.getRank();
  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}

bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
  return {};
}

/// Folds a tensor.cast op into a consuming PackOp if the cast source is more
/// static than the consuming op.
struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp op,
                                PatternRewriter &rewriter) const override {
    // ... (cast preconditions and newOperands/newMixedTileSizes computation
    // elided in this excerpt)
    PackOp newOp =
        PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
                       op.getInnerDimsPos(), newMixedTileSizes,
                       op.getPaddingValue(), op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op, casting the result back to the original type if needed.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement =
        (newResult.getType() != oldResult.getType())
            ? tensor::CastOp::create(rewriter, op->getLoc(),
                                     oldResult.getType(), newResult)
            : newResult;
    rewriter.replaceOp(op, {replacement});
    return success();
  }
};
void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}

ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
  ShapedType destType = getDestType();
  int64_t destRank = destType.getRank();
  return getSourceType().getShape().take_front(destRank);
}

SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getSourceType().getShape();
  SmallVector<int64_t> res;
  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);
  return res;
}

void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                     Value dest, ArrayRef<int64_t> innerDimsPos,
                     ArrayRef<OpFoldResult> innerTiles,
                     ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}

Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                        Value source,
                                        ArrayRef<OpFoldResult> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr sym0, sym1;
  bindSymbols(b.getContext(), sym0, sym1);
  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  auto srcType = llvm::cast<RankedTensorType>(source.getType());
  for (auto i :
       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
    if (srcType.isDynamicDim(i))
      mixedSizes.push_back(
          tensor::DimOp::create(b, loc, source, i).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector<OpFoldResult>(
        mixedSizes, invertPermutationVector(outerDimsPerm));
  }

  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

  auto elemType = srcType.getElementType();
  return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
}

UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return UnPackOp::create(b, loc, transposedSource, getDest(),
                          metadata.innerDimsPos, metadata.innerTiles,
                          metadata.outerDimsPerm);
}
/// Returns true if the srcShape or destShape is different from the one in
/// `op` and populates each with the inferred static shape.
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(op.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  /// unpack(pack(x)) -> x
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }
  /// unpack(destinationStyleOp(x)) -> unpack(x)
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(unPackOp,
                             [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }
  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
  if (unPackOp->hasOneUse()) {
    auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(unPackOp);
      auto newDest = tensor::ExtractSliceOp::create(
          rewriter, unPackOp->getLoc(), unPackOp.getDest(),
          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
          extractSliceUser.getMixedStrides());
      rewriter.modifyOpInPlace(unPackOp, [&]() {
        unPackOp.setDpsInitOperand(0, newDest);
        unPackOp.getResult().setType(newDest.getType());
      });
      rewriter.replaceOp(extractSliceUser, unPackOp);
      return success();
    }
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = tensor::CastOp::create(rewriter, loc, newSrcType,
                                      unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = tensor::CastOp::create(rewriter, loc, newDestType,
                                    unPackOp.getDest());
    }
    Value newOp = UnPackOp::create(
        rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
        unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp);
    return success();
  }

  return failure();
}
bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
  // Rank-reduced folding is not supported.
  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
    return false;
  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
      !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
    return false;
  RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
  SmallVector<int64_t> outerShapeWithoutTranspose =
      getPackedOuterShapeWithoutTransposition(*this);
  for (auto [pos, tileSize] :
       llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
    if (unpackedTypeAfterFold.isDynamicDim(pos))
      return false;
    if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
      return false;
    if (ShapedType::isDynamic(tileSize))
      return false;
    int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
                          unpackedTypeAfterFold.getDimSize(pos);
    if (paddingSize >= tileSize)
      return false;
  }
  return true;
}
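// Padding arithmetic for the fold above (illustrative): with outer extent 4
// and tile size 32, the unpack covers 4 * 32 = 128 elements. Extracting a
// slice of 100 implies paddingSize = 128 - 100 = 28 < 32, so the slice only
// trims the incomplete last tile and can be folded; a slice of 60 would imply
// paddingSize = 68 >= 32, dropping whole tiles, so the fold is rejected.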
bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}

/// Folds a tensor.cast op into a consuming UnPackOp if the cast source is
/// more static than the consuming op.
struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp op,
                                PatternRewriter &rewriter) const override {
    // ... (cast preconditions and newOperands computation elided)
    Value sourceTensor = newOperands[0];

    // Get the updated mixed-tile-sizes attribute.
    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
        rewriter, sourceTensor.getType(), op.getMixedTiles());

    UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
                                      newOperands[1], op.getInnerDimsPos(),
                                      newMixedTileSizes, op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op, casting the result back to the original type if needed.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement =
        (newResult.getType() != oldResult.getType())
            ? tensor::CastOp::create(rewriter, op->getLoc(),
                                     oldResult.getType(), newResult)
            : newResult;
    rewriter.replaceOp(op, {replacement});
    return success();
  }
};
SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{
      utils::IteratorType::reduction, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};
}

SmallVector<AffineMap>
BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2, d3;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2, d3);
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
  return indexingMaps;
}

unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }

std::string BatchReduceMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

/// Returns true if the user-defined indexing maps differ from the default
/// maps, i.e. the op carries broadcast and/or transpose semantics.
bool BatchReduceMatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Returns true if the given broadcast map `bcastMap` is valid for this op.
bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                                    bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };
  // ... (per-arity validity checks elided in this excerpt)
  return isValid;
}
void BatchReduceMatmulOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "BatchReduceMatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  auto toType = block.getArgument(2).getType();
  Value castValA =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
  Value castValB =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  Value addVal =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
}

ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  // Optional `indexing_maps` parsing mirrors BatchMatmulOp::parse above.
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual() || parser.parseLSquare())
      return failure();
    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);
      if (parser.parseOptionalComma())
        break;
    } while (true);
    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  return parseNamedStructuredOp(parser, result,
                                BatchReduceMatmulOp::getNumRegionArgs(),
                                BatchReduceMatmulOp::getRegionBuilder());
}

void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }

  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}

/// Verify the user defined indexing maps.
LogicalResult BatchReduceMatmulOp::verify() {
  // Verification of pure batch_reduce_matmul is handled by
  // verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
                                        SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void BatchReduceMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  // ... (dialect-wide canonicalization patterns elided in this excerpt)
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
SmallVector< int64_t > outerDimsPerm
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
SmallVector< OpFoldResult > innerTiles
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static Operation * findPayloadOp(Block *body, bool initFirst=false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
ElementwiseArityGroup arityGroup
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
SmallVector< int64_t > innerDimsPos
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
union mlir::linalg::@1224::ArityGroupAndKind::Kind kind
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Values to kDynamic,...
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if each dimension of sourceShape is no larger than the corresponding dimension of limitShape.
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx to startIdx + num.
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
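A minimal sketch, assuming b, loc, and a ranked tensor or memref Value val are in scope: the helper folds to a constant index when the dimension is static and otherwise emits tensor::DimOp or memref::DimOp.

  Value d0 = createOrFoldDimOp(b, loc, val, /*dim=*/0);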
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
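A minimal round-trip sketch between the (m, r) parameters and the enum:

  // F(4, 3): output tile size m = 4, filter size r = 3.
  if (std::optional<WinogradConv2DFmr> fmr = getWinogradConv2DFmr(4, 3)) {
    auto [m, r] = getFmrFromWinogradConv2DFmr(*fmr); // yields m == 4, r == 3
    (void)m;
    (void)r;
  }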
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
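A hedged sketch of the usual idiom: an op folder (SomeOp is hypothetical) delegates to foldMemRefCast so that memref.cast producers are absorbed in place.

  LogicalResult SomeOp::fold(FoldAdaptor,
                             SmallVectorImpl<OpFoldResult> &) {
    return memref::foldMemRefCast(*this);
  }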
Kind
An enumeration of the kinds of predicates.
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
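A minimal sketch, assuming builder, loc, and a tensor-typed Value t are in scope: each entry is an IntegerAttr for a static dimension or a Value for a dynamic one.

  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(builder, loc, t);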
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
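A minimal sketch, assuming an OpFoldResult ofr in scope:

  if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
    // Statically known: *cst holds the integer.
  } else {
    // Dynamic: the OpFoldResult wraps an SSA value.
    Value dynamicSize = cast<Value>(ofr);
    (void)dynamicSize;
  }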
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
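A minimal sketch, assuming an MLIRContext *ctx: builds the permutation map (d0, d1) -> (d1, d0).

  AffineExpr d0, d1;
  bindDims(ctx, d0, d1);
  AffineMap swapMap =
      AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, {d1, d0}, ctx);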
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
Parses a single MLIR attribute into the given MLIR context, returning it if valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
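A minimal round-trip sketch, assuming b, loc, and a Value v are in scope: the OpFoldResult carries an Attribute when v is constant, and getValueOrCreateConstantIndexOp materializes it back into a Value only when needed.

  OpFoldResult ofr = getAsOpFoldResult(v);
  Value materialized = getValueOrCreateConstantIndexOp(b, loc, ofr);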
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace directly.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
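A minimal sketch tying the permutation helpers together on a concrete vector:

  SmallVector<int64_t> shape = {4, 8, 16};
  SmallVector<int64_t> perm = {2, 0, 1};  // isPermutationVector(perm) is true
  applyPermutationToVector(shape, perm);                    // shape == {16, 4, 8}
  SmallVector<int64_t> inv = invertPermutationVector(perm); // inv == {1, 2, 0}
  applyPermutationToVector(shape, inv);                     // back to {4, 8, 16}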
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
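A hedged sketch of a minimal OpRewritePattern, assuming a hypothetical FooOp with one operand and one result of the same type:

  struct FoldFoo : OpRewritePattern<FooOp> {
    using OpRewritePattern<FooOp>::OpRewritePattern;
    LogicalResult matchAndRewrite(FooOp op,
                                  PatternRewriter &rewriter) const override {
      // Replace foo(x) with x; a real pattern would guard this rewrite.
      rewriter.replaceOp(op, op.getOperand());
      return success();
    }
  };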
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
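A hedged sketch of generic op construction through OperationState; the "foo.bar" op name and "some_attr" attribute are illustrative, and b, loc, operands, and resultTypes are assumed in scope.

  OperationState state(loc, "foo.bar");
  state.addOperands(operands);
  state.addTypes(resultTypes);
  state.addAttribute("some_attr", b.getUnitAttr());
  Region &bodyRegion = *state.addRegion(); // populated before create for region-holding ops
  (void)bodyRegion;
  Operation *op = b.create(state);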
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has a source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has a source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override