#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
auto type = cast<ShapedType>(v.getType());
if (!type.isDynamicDim(dim))
    .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
      return tensor::DimOp::create(builder, loc, v, dim);
    .Case<MemRefType>([&](MemRefType t) -> Value {
      return memref::DimOp::create(builder, loc, v, dim);
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return memref::SubViewOp::create(b, loc, source, offsets, sizes,
if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
llvm_unreachable("Expected MemRefType or TensorType");
auto shapedType = llvm::cast<ShapedType>(source.getType());
if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
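// Both helpers above dispatch on the shaped type of the source value:
// tensors get tensor.dim / tensor.extract_slice, memrefs get memref.dim /
// memref.subview, and any other type is rejected via llvm_unreachable.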
for (auto containers : {inputTypes, outputTypes}) {
  for (auto t : containers) {
opBuilder.createBlock(&region, {}, argTypes, argLocs);
regionBuilder(b, *body, attrs, emitError);
std::optional<TypeRange> resultTensorTypes,
if (!resultTensorTypes)
  copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
          llvm::IsaPred<RankedTensorType>);
state.addOperands(inputs);
state.addOperands(outputs);
state.addTypes(derivedResultTypes);
state.addAttributes(attributes);
    "operandSegmentSizes",
    static_cast<int32_t>(outputs.size())}));
Region &region = *state.addRegion();
    state.attributes.getAttrs(), {},
std::optional<TypeRange> resultTensorTypes,
indexingMapsAttrVal =
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
    attributes, regionBuilder);
std::optional<TypeRange> resultTensorTypes,
indexingMapsAttrVal =
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
    attributes, regionBuilder);
std::optional<TypeRange> resultTensorTypes,
indexingMapsAttrVal =
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
    attributes, regionBuilder);
                                bool addOperandSegmentSizes = true) {
SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
    parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
if (addOperandSegmentSizes) {
  attrs.append("operandSegmentSizes",
               {static_cast<int32_t>(inputsOperands.size()),
                static_cast<int32_t>(outputsOperands.size())}));
    {static_cast<int32_t>(inputsOperands.size()),
     static_cast<int32_t>(outputsOperands.size())}));
std::optional<RegisteredOperationName> info =
return parser.emitError(attrsLoc)
       << "'" << result.name.getStringRef() << "' op ";
p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
if (!outputs.empty())
  p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
      llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                    "region expects {0} args, got {1}",
                    numRegionArgs, inputTypes.size() + outputTypes.size()));
ParseResult result = success();
    opBuilder, region, inputTypes, outputTypes, attrs,
unsigned numRegionArgs,
result.addTypes(outputTensorsTypes);
std::unique_ptr<Region> region = std::make_unique<Region>();
if (resultTypes.empty())
class RegionBuilderHelper {
      : builder(builder), block(block) {}
if (!isFloatingPoint(arg)) {
  emitError() << "unsupported non numeric type";
  llvm_unreachable("unsupported non numeric type");
builder.setInsertionPointToEnd(&block);
  return math::ExpOp::create(builder, arg.getLoc(), arg);
  return math::LogOp::create(builder, arg.getLoc(), arg);
  return math::AbsFOp::create(builder, arg.getLoc(), arg);
  return math::CeilOp::create(builder, arg.getLoc(), arg);
  return math::FloorOp::create(builder, arg.getLoc(), arg);
  return arith::NegFOp::create(builder, arg.getLoc(), arg);
case UnaryFn::reciprocal: {
  auto one = arith::ConstantOp::create(builder, arg.getLoc(),
                                       ::cast<TypedAttr>(oneAttr));
  return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
  return math::RoundOp::create(builder, arg.getLoc(), arg);
  return math::SqrtOp::create(builder, arg.getLoc(), arg);
  return math::RsqrtOp::create(builder, arg.getLoc(), arg);
case UnaryFn::square:
  return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
  return math::TanhOp::create(builder, arg.getLoc(), arg);
  return math::ErfOp::create(builder, arg.getLoc(), arg);
emitError() << "unsupported unary function";
llvm_unreachable("unsupported unary function");
bool allComplex = isComplex(arg0) && isComplex(arg1);
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
bool allInteger = isInteger(arg0) && isInteger(arg1);
if (!allComplex && !allFloatingPoint && !allInteger) {
      << "Cannot build binary Linalg operation: expects allComplex, "
         "allFloatingPoint, or allInteger, got "
  llvm_unreachable("unsupported non numeric type");
builder.setInsertionPointToEnd(&block);
  return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
  return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
  return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
  return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
  return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
  emitError() << "unsupported operation: sub with bools";
  llvm_unreachable("unsupported operation: sub with bools");
return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
  return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
  return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
  return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
  return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
  return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
  emitError() << "unsupported operation: div with bools";
  llvm_unreachable("unsupported operation: div with bools");
return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::div_unsigned:
  if (!allInteger || allBool) {
    emitError() << "unsupported operation: unsigned div not on uint";
    llvm_unreachable("unsupported operation: unsigned div not on uint");
  return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::max_signed:
  if (allFloatingPoint)
    return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
  return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::min_signed:
  if (allFloatingPoint)
    return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
  return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::max_unsigned:
  if (allFloatingPoint)
    return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
  return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::min_unsigned:
  if (allFloatingPoint)
    return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
  return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
assert(allFloatingPoint);
return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
emitError() << "unsupported binary function";
llvm_unreachable("unsupported binary function");
bool tailFloatingPoint =
    isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
builder.setInsertionPointToEnd(&block);
case TernaryFn::select:
  if (!headBool && !(tailFloatingPoint || tailInteger))
    llvm_unreachable("unsupported non numeric type");
  return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
emitError() << "unsupported ternary function";
llvm_unreachable("unsupported ternary function");
case TypeFn::cast_signed:
  return cast(toType, operand, false);
case TypeFn::cast_unsigned:
  return cast(toType, operand, true);
emitError() << "unsupported type conversion function";
llvm_unreachable("unsupported type conversion function");
builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
YieldOp::create(builder, loc, values);
Value constant(const std::string &value) {
  builder.setInsertionPointToEnd(&block);
  Location loc = builder.getUnknownLoc();
  return arith::ConstantOp::create(builder, loc,
                                   ::cast<TypedAttr>(valueAttr));
Value index(int64_t dim) {
  builder.setInsertionPointToEnd(&block);
  return IndexOp::create(builder, builder.getUnknownLoc(), dim);
Type getIntegerType(unsigned width) {
builder.setInsertionPointToEnd(&block);
auto loc = operand.getLoc();
if (isa<UnknownLoc>(loc)) {
bool isComplex(Value value) {
  return llvm::isa<ComplexType>(value.getType());
bool isFloatingPoint(Value value) {
  return llvm::isa<FloatType>(value.getType());
bool isInteger(Value value) {
  return llvm::isa<IntegerType>(value.getType());
LogicalResult matchAndRewrite(CopyOp copyOp,
if (copyOp.getInputs() != copyOp.getOutputs())
if (copyOp.hasPureBufferSemantics())
rewriter.replaceOp(copyOp, copyOp.getInputs());
results.add<EraseSelfCopy>(context);
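// A linalg.copy whose input and output are the same value is a no-op: on
// tensors it is replaced by its input (the pure-buffer branch, elided here,
// presumably erases the op outright since a self-copy writes nothing new).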
template <typename TensorReshapeOp>
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
TensorReshapeOp newInit;
if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
  newInit = TensorReshapeOp::create(
      rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
      reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
      reshapeOp.getStaticOutputShape());
  newInit = TensorReshapeOp::create(
      rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
      reshapeOp.getReassociation());
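// Folding rule: reshape(fill(value, init)) == fill(value, reshape(init)),
// so the reshape is moved onto the init operand and the fill re-created on
// the reshaped type. ExpandShapeOp needs the extra output-shape operands,
// hence the `if constexpr` split above.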
LogicalResult matchAndRewrite(tensor::PadOp padOp,
auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
Value padValue = padOp.getConstantPaddingValue();
if (!padValue || fillOp.value() != padValue)
    padOp, "failed to reify tensor.pad op result shape");
    tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
                            padOp.getResultType().getElementType());
    FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
if (replacement.getType() != padOp.getResultType()) {
  replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
                                       padOp.getResultType(), replacement);
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
Value firstDest = insertOp.getDest();
while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
  if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
  bool disjoint = false;
  for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
    if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
        insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
        prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
    int64_t prevStart = prevOp.getStaticOffset(i);
    int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                      prevOp.getStaticStride(i);
    int64_t nextStart = insertOp.getStaticOffset(i);
    int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                      insertOp.getStaticStride(i);
    if (prevEnd < nextStart || nextEnd < prevStart) {
  firstDest = prevOp.getDest();
Value padValue = srcPadOp.getConstantPaddingValue();
if (!padValue || dstFillOp.value() != padValue)
for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
RankedTensorType srcPadType = srcPadOp.getSourceType();
for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
  if (srcPadType.isDynamicDim(i)) {
        tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
    newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
    insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
    newSizes, insertOp.getMixedStrides());
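// Disjointness test used above: with static offsets/sizes/strides, the
// region touched in dimension i is the closed interval
// [start, start + (size - 1) * stride]. Two insert_slice destinations can
// only be walked through when some dimension's intervals do not overlap,
// i.e. prevEnd < nextStart || nextEnd < prevStart.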
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
Value extractedScalar = fillOp.getInputs()[0];
rewriter.replaceOp(extractOp, extractedScalar);
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
if (auto paddingValue = packOp.getPaddingValue())
Value packOpDest = packOp.getDest();
return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
LogicalResult matchAndRewrite(linalg::PackOp packOp,
auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
rewriter.replaceOp(packOp, fillOp.value().result());
LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      copyOp.getOutputs());
if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      fillOp.getOutputs());
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
      transposeOp.getDpsInitOperand(0)->get());
LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
auto concatOperands = concatOp.getInputs();
if (concatOperands.empty()) {
auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
  auto fillOp = v.getDefiningOp<linalg::FillOp>();
  if (fillVal != firstFillVal)
  allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
if (!llvm::all_of(concatOperands.drop_front(),
                  isDefinedByCompatibleFillOp)) {
    concatOp, "not all operands are defined by a compatible fill op");
Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
                                            concatOp.getDim(), allOuts);
    concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
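// tensor.concat of linalg.fill results that all fill the same scalar folds
// to a single fill of the concatenated init tensors.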
results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
            FoldFillWithPack, FoldFillWithPad,
            FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
            FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
            FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
for (ValueRange container : {inputs, outputs}) {
  for (Value v : container) {
    Type t = v.getType();
    blockArgTypes.push_back(
    blockArgLocs.push_back(v.getLoc());
builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
void GenericOp::getAsmBlockArgumentNames(Region &region,
for (Value v : getRegionInputArgs())
for (Value v : getRegionOutputArgs())
  setNameFn(v, "out");
void GenericOp::build(
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
        inputs, outputs, bodyBuild);
void GenericOp::build(
    StringRef libraryCall,
  build(builder, result, resultTensorTypes, inputs, outputs,
          return IteratorTypeAttr::get(builder.getContext(), iter);
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
void GenericOp::build(
    StringRef libraryCall,
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
void GenericOp::build(
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        "", bodyBuild, attributes);
void GenericOp::build(
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        "", bodyBuild, attributes);
auto genericAttrNames = linalgTraitAttrNames();
genericAttrNamesSet.insert_range(genericAttrNames);
for (auto attr : (*this)->getAttrs()) {
  if (attr.getName() == getIteratorTypesAttrName()) {
    auto iteratorTypes =
        llvm::cast<ArrayAttr>(attr.getValue())
            .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
        llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](utils::IteratorType t) -> Attribute {
    genericAttrs.emplace_back(
        getIteratorTypesAttrName(),
  } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
    genericAttrs.push_back(attr);
if (!genericAttrs.empty()) {
  p << genericDictAttr;
genericAttrNames.push_back("operandSegmentSizes");
genericAttrNamesSet.insert(genericAttrNames.back());
bool hasExtraAttrs = false;
if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
if (hasExtraAttrs) {
if (!getRegion().empty()) {
DictionaryAttr dictAttr;
                   dictAttr.getValue().end());
auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
if (!iteratorTypes) {
  return parser.emitError(attributeLocation)
         << "expected " << getIteratorTypesAttrName(result.name)
         << " array attribute";
for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
  auto maybeIteratorType = utils::symbolizeIteratorType(s);
  if (!maybeIteratorType.has_value())
        << "unexpected iterator_type (" << s << ")";
  iteratorTypeAttrs.push_back(
std::unique_ptr<Region> region = std::make_unique<Region>();
result.addTypes(outputTensorsTypes);
                                  LinalgOp linalgOp) {
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
  if (!llvm::isa<MemRefType>(operand.getType()))
  effects.emplace_back(
for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
  if (!llvm::isa<MemRefType>(operand.get().getType()))
  if (linalgOp.payloadUsesValueFromOperand(&operand)) {
void GenericOp::getEffects(
if (!linalgOp.hasPureTensorSemantics())
template <typename OpTy>
LogicalResult matchAndRewrite(OpTy linalgOp,
if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
Block &body = linalgOp->getRegion(0).front();
if (!llvm::hasSingleElement(body))
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
if (linalgOp.hasPureBufferSemantics()) {
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
      linalgOp.getDpsInputOperand(0)->get() !=
          linalgOp.getDpsInitOperand(0)->get()) {
        linalgOp, "expected single input and output to be the same value");
  auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
  if (!yieldArg || yieldArg.getOwner() != &body) {
                                       "cannot fold fill-like op");
if (!linalgOp.hasPureTensorSemantics()) {
      linalgOp, "mixed semantics is not supported yet");
auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
if (!yieldArg || yieldArg.getOwner() != &body)
unsigned argumentNumber = yieldArg.getArgNumber();
Value returnedArg = linalgOp->getOperand(argumentNumber);
Type resultType = linalgOp->getResult(yieldVal.index()).getType();
if (returnType != resultType) {
    returnedArg = sparse_tensor::ConvertOp::create(
        rewriter, linalgOp.getLoc(), resultType, returnedArg);
  if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
    returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
                                         resultType, returnedArg);
  returnedArgs.push_back(returnedArg);
if (returnedArgs.size() != linalgOp->getNumResults())
rewriter.replaceOp(linalgOp, returnedArgs);
results.add<EraseIdentityLinalgOp<GenericOp>>(context);
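// An "identity" linalg op here is one whose indexing maps are all equal and
// whose single-block body just yields block arguments; such an op forwards
// its operands, modulo tensor.cast / sparse_tensor.convert fix-ups when the
// result types differ.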
for (Type outputType : outputTypes) {
  if (llvm::isa<RankedTensorType>(outputType))
void MapOp::getAsmBlockArgumentNames(Region &region,
for (Value v : getRegionInputArgs())
if (!getResults().empty())
  setNameFn(getResults().front(), "mapped");
build(builder, result, TypeRange{}, inputs, init);
if (llvm::isa<RankedTensorType>(initType))
    inputs, {}, bodyBuild);
                         bool initFirst = false) {
for (auto &operand : operands) {
      llvm::cast<ShapedType>(operand.getType()).getElementType(),
  payloadOpOperands.push_back(block.getArguments().back());
  for (const auto &arg : block.getArguments().drop_back())
    payloadOpOperands.push_back(arg);
    TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
std::optional<OperationName> payloadOpName;
if (failed(operationName))
payloadOpName = operationName.value();
if (payloadOpName.has_value()) {
for (const auto &[operand, bbArg] :
  if (bbArg != operand)
for (const auto &[operand, bbArg] :
  if (bbArg != operand)
return yieldOp.getNumOperands() == 1 &&
       yieldOp.getOperand(0).getDefiningOp() &&
       yieldOp.getOperand(0).getDefiningOp() == &payload;
std::string attrToElide;
for (const auto &attr : payloadOp->getAttrs()) {
      llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
    attrToElide = attr.getName().str();
    elidedAttrs.push_back(attrToElide);
Block *mapper = getBody();
if (!useShortForm) {
    [&](auto arg) { p.printRegionArgument(arg); });
auto *bodyBlock = getBody();
auto blockArgs = bodyBlock->getArguments();
if (getInputs().size() != blockArgs.size())
  return emitOpError() << "expects number of operands to match the arity of "
                       << getInputs().size() << " and " << blockArgs.size();
for (const auto &[bbArgType, inputArg] :
     llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
  auto inputElemType =
      llvm::cast<ShapedType>(inputArg.getType()).getElementType();
  if (bbArgType != inputElemType) {
    return emitOpError() << "expected element type of input " << inputElemType
                         << " to match bbArg type " << bbArgType;
auto outputShape = getInit().getType().getShape();
auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
if (inputElemShape != outputShape) {
  return emitOpError() << "expected shape of input (" << inputElemShape
                       << ") to match shape of output (" << outputShape
int64_t rank = getInit().getType().getRank();
ArrayAttr MapOp::getIndexingMaps() {
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
void MapOp::getEffects(
void ReduceOp::getAsmBlockArgumentNames(Region &region,
for (Value v : getRegionInputArgs())
for (Value v : getRegionOutputArgs())
  setNameFn(v, "init");
void ReduceOp::getAsmResultNames(
if (!getResults().empty())
  setNameFn(getResults().front(), "reduced");
void ReduceOp::build(
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
for (Value init : inits) {
  if (llvm::isa<RankedTensorType>(initType))
      inputs, inits, bodyBuild);
    llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
    utils::IteratorType::parallel);
for (int64_t reductionDim : getDimensions())
  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
return iteratorTypes;
ArrayAttr ReduceOp::getIndexingMaps() {
    llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
  affineMaps.push_back(resultMap);
void ReduceOp::getEffects(
                                       StringRef attributeName) {
std::optional<OperationName> payloadOpName;
if (failed(operationName))
payloadOpName = operationName.value();
if (payloadOpName.has_value()) {
p << ' ' << attributeName << " = [" << attributeValue << "] ";
Block *mapper = getBody();
if (!useShortForm) {
    [&](auto arg) { p.printRegionArgument(arg); });
for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
      llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
    return emitOpError() << "expects all inputs to have the same shapes. "
                            "Shape at input-index "
                         << " is not equal to the shape at input-index 0.";
for (int64_t i = 1; i < getNumDpsInits(); ++i) {
  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
      llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
    return emitOpError() << "expects all outputs to have the same shapes. "
                            "Shape at output-index "
                         << " is not equal to the shape at output-index 0.";
auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
for (int64_t dimension : dimensionsRef) {
  if (dimension < 0 || dimension >= inputType.getRank()) {
    return emitOpError()
           << "dimensions for reduction should be in the range [0, "
           << inputType.getRank() - 1 << "].";
  dimensionsToReduce.insert(dimension);
auto inputDims = inputType.getShape();
auto initDims = initType.getShape();
  if (!dimensionsToReduce.count(en.index()))
    reducedInputDims.push_back(en.value());
if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
  return emitOpError() << "number of dimensions after reduction "
                       << reducedInputDims.size()
                       << " doesn't match the init rank "
                       << initType.getRank();
if (reducedInputDims != initDims)
  return emitOpError() << "init dimensions [" << initDims
                       << "] doesn't match input dimensions after reduction ["
                       << reducedInputDims << "]";
Block *block = getBody();
  return emitOpError()
         << "mismatching number of operands and block arguments";
for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
  Type inputElementType =
      llvm::cast<ShapedType>(input.getType()).getElementType();
  if (inputElementType != bbArg.getType())
    return emitOpError()
           << "input element type " << inputElementType
           << " does not match corresponding block argument type "
for (auto [output, bbArg] : llvm::zip(
         getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
  auto outputElementType =
      llvm::cast<ShapedType>(output.getType()).getElementType();
  if (outputElementType != bbArg.getType())
    return emitOpError()
           << "output element type " << outputElementType
           << " does not match corresponding block argument type "
linalg::YieldOp::create(b, loc, args[0]);
if (llvm::isa<RankedTensorType>(initType))
void TransposeOp::getAsmResultNames(
if (!getResults().empty())
  setNameFn(getResults().front(), "transposed");
return emitOpError("permutation is not valid");
auto inputType = getInput().getType();
auto initType = getInit().getType();
int64_t rank = inputType.getRank();
if (rank != initType.getRank())
  return emitOpError() << "input rank " << rank
                       << " does not match init rank " << initType.getRank();
if (rank != static_cast<int64_t>(permutationRef.size()))
  return emitOpError() << "size of permutation " << permutationRef.size()
                       << " does not match the argument rank " << rank;
auto inputDims = inputType.getShape();
auto initDims = initType.getShape();
for (int64_t i = 0; i < rank; ++i) {
  int64_t inputDim = inputDims[permutationRef[i]];
  int64_t initDim = initDims[i];
  if (inputDim != initDim) {
    return emitOpError() << "dim(result, " << i << ") = " << initDim
                         << " doesn't match dim(input, permutation[" << i
                         << "]) = " << inputDim;
int64_t rank = getInit().getType().getRank();
ArrayAttr TransposeOp::getIndexingMaps() {
  int64_t rank = getInit().getType().getRank();
      llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
void TransposeOp::getEffects(
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
if (!isa<TensorType>(getInput().getType()))
if (getPermutation().size() == 0) {
  result.push_back(getInput());
result.push_back(getInput());
auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
if (!defTransposeOp)
foldedPerms.reserve(perms.size());
for (int64_t perm : perms)
  foldedPerms.push_back(defPerms[perm]);
    transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
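// Composing the two permutations: the producer applies defPerms and the
// consumer applies perms on top of it, so the fused transpose uses
// foldedPerms[i] = defPerms[perms[i]].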
Value input = transposeOp.getInput();
BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
unsigned dimensionSize = dimensions.size();
for (unsigned i = 0; i < dimensionSize; ++i)
  resultDimensions.push_back(invertPerm[dimensions[i]]);
Value broadcastInput = broadcastOp.getInput();
Location loc = transposeOp.getLoc();
auto broadcastInputTy =
    mlir::cast<RankedTensorType>(broadcastInput.getType());
unsigned inputRank = broadcastInputTy.getRank();
for (unsigned i = 0; i < inputRank; ++i) {
  if (broadcastInputTy.isDynamicDim(i)) {
    dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
        broadcastInputTy.getDimSize(i)));
Value transposeInit = tensor::EmptyOp::create(
    rewriter, transposeOp.getLoc(), transposeResultShapes,
    broadcastInputTy.getElementType());
Value transposeResult =
    TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
                        transposeInit, resultPerms)
    transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
if (llvm::isa<RankedTensorType>(initType))
void BroadcastOp::getAsmResultNames(
if (!getResults().empty())
  setNameFn(getResults().front(), "broadcasted");
auto inputType = getInput().getType();
auto initType = getInit().getType();
int64_t inputRank = inputType.getRank();
int64_t initRank = initType.getRank();
auto inputShape = inputType.getShape();
auto initShape = initType.getShape();
if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
  return emitOpError() << "input rank plus added dimensions does not "
                          "match init rank. input rank: "
                       << ", dimensions size: " << dimensionsRef.size()
                       << ", init rank: " << initRank;
if (dim < 0 || dim >= initRank)
  return emitOpError() << "dimension " << idx
                       << " is out of range. expected range: [0, "
                       << initRank - 1 << "], got: " << dim;
for (auto dim : llvm::seq<int64_t>(0, initRank)) {
  if (!llvm::is_contained(dimensionsRef, dim))
    dimMap.push_back(dim);
for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
  if (inputShape[inputDimIdx] != initShape[initDimIdx])
    return emitOpError() << "input dim " << inputDimIdx
                         << " should match init dim " << initDimIdx
                         << ". input: " << inputShape[inputDimIdx]
                         << ", init: " << initShape[initDimIdx];
int64_t rank = getInit().getType().getRank();
ArrayAttr BroadcastOp::getIndexingMaps() {
  int64_t rank = getInit().getType().getRank();
void BroadcastOp::getEffects(
auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
if (!defBroadcastOp)
Value init = broadcastOp.getInit();
int64_t initRank = cast<ShapedType>(init.getType()).getRank();
for (auto dim : llvm::seq<int64_t>(0, initRank)) {
  if (!llvm::is_contained(dimensions, dim))
    dimMap.push_back(dim);
for (auto dim : defDimensions)
  foldedDims.push_back(dimMap[dim]);
llvm::sort(foldedDims);
    broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
if (getNumOperands() > 0)
  p << ' ' << getOperands();
if (getNumOperands() > 0)
  p << " : " << getOperandTypes();
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
for (OpOperand &opOperand : op->getOpOperands()) {
      linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
  if (isa<MemRefType, RankedTensorType>(elementType))
  if (opOperand.get().getType() != elementType)
    return op.emitOpError("type of yield operand ")
           << (opOperand.getOperandNumber() + 1) << " ("
           << opOperand.get().getType() << ") doesn't match "
           << "the element type of the enclosing linalg.generic op ("
           << elementType << ")";
auto *parentOp = (*this)->getParentOp();
if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
  return emitOpError("expected single non-empty parent region");
if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
return emitOpError("expected parent op with LinalgOp interface");
auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
  return emitOpError("expected parent op with LinalgOp interface");
if (linalgOp.getNumLoops() <= getDim())
  return emitOpError("expected dim (")
         << getDim() << ") to be lower than the number of loops ("
         << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
uint64_t dim = getDim();
assert(dim < loopBounds.size() && "Dim is out of bounds");
if (loopBounds[dim] == 1)
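// A loop dimension with trip count 1 can only ever produce index 0, so
// linalg.index along such a dimension folds to the constant zero.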
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
for (unsigned i = 0; i < num; ++i)
auto rangeA = llvm::make_range(a.begin(), a.end());
auto rangeB = llvm::make_range(b.begin(), b.end());
auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
return llvm::to_vector<4>(concatRanges);
if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
  for (auto size : memref.getShape())
  if (auto as = memref.getMemorySpace()) {
    if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
      ss << "as" << attr.getInt();
if (auto vec = llvm::dyn_cast<VectorType>(t)) {
      vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
assert(isa<LinalgOp>(op));
std::string fun = "";
if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
  fun = stringifyEnum(ufa.getValue()).str() + "_";
} else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
  fun = stringifyEnum(bfa.getValue()).str() + "_";
llvm::replace(name, '.', '_');
llvm::raw_string_ostream ss(name);
return std::string();
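// The resulting library call name is the op name with '.' replaced by '_',
// any unary/binary fn prefix, and the mangled operand types appended; e.g.
// something along the lines of "linalg_matmul_view4x8xf32_..." (the exact
// mangling follows appendMangledType above).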
LogicalResult matchAndRewrite(LinalgOp op,
for (OpOperand &opOperand : op->getOpOperands()) {
  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
LogicalResult matchAndRewrite(tensor::CastOp castOp,
auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
if (castOp->getBlock() != linalgOp->getBlock())
OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
    linalgOp.getDpsInits().end());
outputOperands[resultNumber] = newOperand;
newOperands.append(outputOperands.begin(), outputOperands.end());
    linalgOp->result_type_end());
resultTypes[resultNumber] = resultType;
Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
Value castBack = tensor::CastOp::create(
results[resultNumber] = castBack;
if (linalgOp.isScalar(&opOperand))
Value src = opOperand.get();
auto sourceType = llvm::cast<RankedTensorType>(src.getType());
auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
  Value castSource = castOp.getSource();
  auto castSourceType =
      llvm::dyn_cast<RankedTensorType>(castSource.getType());
  if (castSourceType && castSourceType.hasStaticShape())
    sourceShape = castSourceType.getShape();
for (unsigned i = 0; i < sourceShape.size(); i++) {
  if (sourceType.isDynamicDim(i))
  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
    affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
static void createNewOperandWithStaticSizes(
    bool &changeNeeded) {
newOperands.push_back(src);
if (linalgOp.isScalar(opOperand))
auto sourceType = llvm::cast<RankedTensorType>(src.getType());
Type resultType = sourceType;
if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
  resultTypes.push_back(resultType);
AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
bool newOperandNeeded = false;
for (unsigned i = 0; i < sourceShape.size(); i++) {
  int64_t dimShape = sourceShape[i];
  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
    newShape.push_back(dimShape);
  newShape.push_back(affineExprToSize[dimExpr]);
  newOperandNeeded = true;
    sourceType.getEncoding());
if (newOperandNeeded) {
  changeNeeded = true;
  Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
  newOperands[index] = newOperand;
if (linalgOp.isDpsInit(opOperand))
  resultTypes.push_back(resultType);
LogicalResult matchAndRewrite(LinalgOp linalgOp,
if (!linalgOp.hasPureTensorSemantics())
if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
      return !map.isProjectedPermutation();
populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
bool changeNeeded = false;
newOperands.reserve(linalgOp->getNumOperands());
resultTypes.reserve(linalgOp.getNumDpsInits());
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                  affineExprToSize, linalgOp, newOperands,
                                  resultTypes, changeNeeded);
Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
  Value newResult = std::get<1>(it);
  Value oldResult = std::get<0>(it);
  replacements.push_back(
      (newType != oldType)
          ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
rewriter.replaceOp(linalgOp, replacements);
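// Overall flow of this pattern: collect static sizes implied by all operands
// (populateMap), rebuild each operand with a tensor.cast to the more static
// type where possible, clone the op on the refined types, and cast results
// back to their original types so existing uses are unaffected.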
ShapedType inputType = getInputOperandType();
ShapedType outputType = getOutputOperandType();
  return emitOpError("incompatible output shape");
int64_t inputRank = getInputOperandRank();
int64_t dimension = getDimension();
if ((dimension < 0) || (dimension >= inputRank))
  return emitOpError("incorrect dimension specified");
int64_t operandRank = getInputOperandRank();
Value source = getInput();
for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
  loopBounds[dim].offset = zero;
  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
  loopBounds[dim].stride = one;
    utils::IteratorType::parallel);
iteratorTypes[getDimension()] = utils::IteratorType::reduction;
return iteratorTypes;
FailureOr<TilingResult>
int64_t rank = getInputOperandRank();
    getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
  return emitOpError("failed to compute input slice");
tiledOperands.emplace_back(inputSlice->getResult(0));
    getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
  return emitOpError("failed to compute output slice");
tiledOperands.emplace_back(outputSlice->getResult(0));
if (hasPureTensorSemantics())
  resultTypes.push_back(tiledOperands[1].getType());
    mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
if (resultNumber == 0) {
  resultOffsets.assign(offsets.begin(), offsets.end());
  resultSizes.assign(sizes.begin(), sizes.end());
Location loc = getOperation()->getLoc();
auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
  if (!outputShapedType.isDynamicDim(dim)) {
    shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
reifiedReturnShapes.emplace_back(std::move(shapes));
void SoftmaxOp::getEffects(
if (!llvm::isa<MemRefType>(operand.getType()))
    &getOperation()->getOpOperand(index), 0,
for (OpOperand &operand : getDpsInitsMutable()) {
  if (!llvm::isa<MemRefType>(operand.get().getType()))
                                    int64_t dim, bool allParallel = false) {
    utils::IteratorType::parallel);
iteratorTypes[dim] = utils::IteratorType::reduction;
for (int i = 0; i < inputRank; i++) {
return std::make_tuple(iteratorTypes, indexingMaps);
template <typename T>
auto inputType = cast<ShapedType>(input.getType());
int64_t inputRank = inputShape.size();
auto [iteratorTypes, indexingMaps] =
assert(indexingMaps.size() == 2 &&
       "We should have two maps: 1 for the input, 1 for the output");
assert(indexingMaps[0].isIdentity() && "input map should be identity");
auto genericOp = linalg::GenericOp::create(
    builder, loc, output.getType(), input, output, indexingMaps,
      Value result = T::create(b, loc, args[0], args[1]);
      linalg::YieldOp::create(b, loc, result);
return genericOp.getResult(0);
auto inputType = cast<ShapedType>(input.getType());
int64_t inputRank = inputShape.size();
    builder, inputRank, dim, true);
assert(indexingMaps.size() == 2 && "We should have one map for each input");
assert(indexingMaps[0].isIdentity() && "input map should be identity");
indexingMaps.push_back(indexingMaps[0]);
auto genericOp = linalg::GenericOp::create(
    indexingMaps, iteratorTypes,
      Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
      Value result = math::ExpOp::create(b, loc, diff);
      linalg::YieldOp::create(b, loc, result);
return genericOp.getResult(0);
                  Value denominator, Value output, int64_t dim) {
auto inputType = cast<ShapedType>(numerator.getType());
int64_t inputRank = inputShape.size();
    builder, inputRank, dim, true);
assert(indexingMaps.size() == 2 &&
       "We should have one map for each input (2)");
assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
indexingMaps.push_back(indexingMaps[0]);
auto genericOp = linalg::GenericOp::create(
    output, indexingMaps, iteratorTypes,
      Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
      linalg::YieldOp::create(b, loc, result);
return genericOp.getResult(0);
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
Value input = getInput();
ShapedType inputType = getInputOperandType();
Type elementType = inputType.getElementType();
int64_t reductionDim = getDimension();
Value output = getOutput();
dims.erase(dims.begin() + reductionDim);
Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
    elementType, b, loc,
Value neutralForMaxFInit =
    linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
    reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
    linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
    reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
    buildDivOp(b, loc, numerator, denominator, output, reductionDim);
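// Softmax decomposition used above, in four linalg steps along the
// reduction dimension: (1) max-reduce the input, (2) subtract the running
// max and exponentiate, (3) sum-reduce the exponentials, (4) divide the
// numerator by the sum.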
auto filterType = cast<ShapedType>(getFilter().getType());
int64_t filterH = filterShape[getFilterHDim()];
int64_t filterW = filterShape[getFilterWDim()];
WinogradConv2DFmr fmr = getFmr();
if (filterH != r && filterH != 1)
  return emitOpError("expect filter height either equals to r or 1");
if (filterW != r && filterW != 1)
  return emitOpError("expect filter width either equals to r or 1");
if (filterH == 1 && filterW == 1)
  return emitOpError("expect either filter height or width equals to r");
expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
expectedOutputShape.push_back(filterShape[getFilterCDim()]);
expectedOutputShape.push_back(filterShape[getFilterFDim()]);
auto outputType = cast<ShapedType>(getOutput().getType());
  return emitOpError("the output shape is not expected");
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
Value filter = getFilter();
int64_t filterRank = getFilterOperandRank();
for (unsigned dim = 0; dim < filterRank; ++dim) {
  loopBounds[dim].offset = zeroAttr;
  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
  loopBounds[dim].stride = oneAttr;
WinogradFilterTransformOp::getLoopIteratorTypes() {
int64_t filterRank = getFilterOperandRank();
    utils::IteratorType::parallel);
return iteratorTypes;
ShapedType filterType = getFilterOperandType();
int64_t filterH = filterShape[getFilterHDim()];
int64_t filterW = filterShape[getFilterWDim()];
WinogradConv2DFmr fmr = getFmr();
int64_t alpha = m + r - 1;
int64_t alphaH = filterH != 1 ? alpha : 1;
int64_t alphaW = filterW != 1 ? alpha : 1;
resultOffsets.append(
    {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
    {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
ShapedType filterType = getFilterOperandType();
int64_t filterH = filterShape[getFilterHDim()];
int64_t filterW = filterShape[getFilterWDim()];
sliceOffsets.append(
    {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                   sizes[getFilterCDim()]});
int64_t filterRank = getFilterOperandRank();
auto filterSlice = tensor::ExtractSliceOp::create(
    builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
tiledOperands.emplace_back(filterSlice);
int64_t outputRank = getOutputOperandRank();
auto outputSlice = tensor::ExtractSliceOp::create(
    builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
tiledOperands.emplace_back(outputSlice);
resultTypes.push_back(tiledOperands[1].getType());
    mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
auto inputType = cast<ShapedType>(getInput().getType());
int64_t inputH = inputShape[getInputHDim()];
int64_t inputW = inputShape[getInputWDim()];
WinogradConv2DFmr fmr = getFmr();
int64_t tileSize = m + r - 1;
auto outputType = cast<ShapedType>(getOutput().getType());
bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
if (ShapedType::isDynamic(inputH)) {
  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
  expectedOutputShape[getOutputTileHDim()] =
      leftTransform ? (inputH - (r - 1)) / m : inputH;
if (ShapedType::isDynamic(inputW)) {
  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
  expectedOutputShape[getOutputTileWDim()] =
      rightTransform ? (inputW - (r - 1)) / m : inputW;
expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
  return emitOpError("the output shape is not expected");
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
Value output = getOutput();
int64_t outputRank = getOutputOperandRank();
for (unsigned dim = 0; dim < outputRank; ++dim) {
  loopBounds[dim].offset = zeroAttr;
  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
  loopBounds[dim].stride = oneAttr;
WinogradInputTransformOp::getLoopIteratorTypes() {
int64_t outputRank = getOutputOperandRank();
    utils::IteratorType::parallel);
return iteratorTypes;
ShapedType outputType = getOutputOperandType();
int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
WinogradConv2DFmr fmr = getFmr();
int64_t alpha = m + r - 1;
int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
                      offsets[getOutputTileWDim()], offsets[getOutputNDim()],
                      offsets[getOutputCDim()]});
resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
                    sizes[getOutputTileWDim()], sizes[getOutputNDim()],
                    sizes[getOutputCDim()]});
FailureOr<TilingResult>
WinogradConv2DFmr fmr = getFmr();
ShapedType outputType = getOutputOperandType();
int64_t alphaH = outputShape[getOutputAlphaHDim()];
int64_t alphaW = outputShape[getOutputAlphaWDim()];
auto identityAffineMap =
auto offsetAffineMap =
    builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
    offsets[getOutputTileHDim()]);
    builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
    offsets[getOutputTileWDim()]);
    builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
    builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
sliceOffsets.append(
    {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
    {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
int64_t inputRank = getInputOperandRank();
auto inputSlice = tensor::ExtractSliceOp::create(
    builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
tiledOperands.emplace_back(inputSlice);
int64_t outputRank = getOutputOperandRank();
auto outputSlice = tensor::ExtractSliceOp::create(
    builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
tiledOperands.emplace_back(outputSlice);
resultTypes.push_back(tiledOperands[1].getType());
    mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
auto valueType = cast<ShapedType>(getValue().getType());
int64_t valueH = valueShape[getValueAlphaHDim()];
int64_t valueW = valueShape[getValueAlphaWDim()];
int64_t valueTileH = valueShape[getValueTileHDim()];
int64_t valueTileW = valueShape[getValueTileWDim()];
WinogradConv2DFmr fmr = getFmr();
bool leftTransform = valueH != 1;
bool rightTransform = valueW != 1;
int64_t outputRank = getOutputOperandRank();
if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  if (valueH != (leftTransform ? m + r - 1 : 1))
    return emitOpError("expect input height equals to input tile size");
  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  if (valueW != (rightTransform ? m + r - 1 : 1))
    return emitOpError("expect input width equals to input tile size");
  expectedOutputShape[getOutputWDim()] =
      (rightTransform ? m : 1) * valueTileW;
expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
auto outputType = cast<ShapedType>(getOutput().getType());
  return emitOpError("the output shape is not expected");
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
Value value = getValue();
int64_t valueRank = getValueOperandRank();
for (unsigned dim = 0; dim < valueRank; ++dim) {
  loopBounds[dim].offset = zeroAttr;
  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
  loopBounds[dim].stride = oneAttr;
WinogradOutputTransformOp::getLoopIteratorTypes() {
int64_t valueRank = getValueOperandRank();
    utils::IteratorType::parallel);
return iteratorTypes;
WinogradConv2DFmr fmr = getFmr();
auto identityAffineMap =
ShapedType valueType = getValueOperandType();
int64_t valueH = valueShape[0];
int64_t valueW = valueShape[1];
    builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
    offsets[getValueTileHDim()]);
    builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
    offsets[getValueTileWDim()]);
    builder, loc, affineMap, sizes[getValueTileHDim()]);
    builder, loc, affineMap, sizes[getValueTileWDim()]);
resultOffsets.append(
    {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
    {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
ShapedType valueType = getValueOperandType();
int64_t alphaH = valueShape[getValueAlphaHDim()];
int64_t alphaW = valueShape[getValueAlphaWDim()];
sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                     offsets[getValueTileWDim()], offsets[getValueNDim()],
                     offsets[getValueFDim()]});
sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                   sizes[getValueTileWDim()], sizes[getValueNDim()],
                   sizes[getValueFDim()]});
int64_t valueRank = getValueOperandRank();
auto valueSlice = tensor::ExtractSliceOp::create(
    builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
tiledOperands.emplace_back(valueSlice);
int64_t outputRank = getOutputOperandRank();
auto outputSlice = tensor::ExtractSliceOp::create(
    builder, loc, getOutput(), resultOffsets, resultSizes, strides);
tiledOperands.emplace_back(outputSlice);
resultTypes.push_back(tiledOperands[1].getType());
    mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
llvm::set_union(explicitSet, defaultSet);
return explicitSet == defaultSet;
    matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
  return matmulOp->emitOpError()
         << "Unexpected dim expression in map result.";
if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
  return matmulOp->emitOpError()
         << "Invalid broadcast requested, should be (d2).";
template <typename OpTy>
                                   AffineMap defaultIndexingMap, bool isLHS) {
assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
        isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
       "Expected BatchMatmulOp or BatchReduceMatmulOp");
  return batchVariantMatmulOp->emitOpError()
         << "Unexpected result dim expression (outside the set of default "
  return batchVariantMatmulOp->emitOpError()
         << "no. of result dim expressions exceeds 3.";
auto hasValidBatchDim = [](AffineMap map) {
if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
  return batchVariantMatmulOp->emitOpError()
         << "Invalid broadcast requested.";
} else if (!hasValidBatchDim(opIndexingMap)) {
  return batchVariantMatmulOp->emitOpError()
         << "Invalid batch dimension expression.";
template <typename OpTy>
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
                                     AffineMap opIndexingMap) {
  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
         "Expected BatchMatmulOp or BatchReduceMatmulOp");
  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&

    return batchVariantMatmulOp->emitOpError()
           << "expects 3 dims, but got (" << opIndexingMap.getNumResults()

  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&

    return batchVariantMatmulOp->emitOpError()
           << "expects 2 dims, but got (" << opIndexingMap.getNumResults()

  auto areValidOutputResultDim = [&](AffineMap outputMap) {
    return isa<BatchMatmulOp>(batchVariantMatmulOp)
               ? outputMap.getResult(0).isFunctionOfDim(0) &&
                     outputMap.getResult(1).isFunctionOfDim(1) &&
                     outputMap.getResult(2).isFunctionOfDim(2)
               : outputMap.getResult(0).isFunctionOfDim(1) &&
                     outputMap.getResult(1).isFunctionOfDim(2);
  };

  if (!areValidOutputResultDim(opIndexingMap)) {
    return batchVariantMatmulOp->emitOpError()
           << "Invalid output map result dimension.";
template <typename OpTy>
static LogicalResult
verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
                                         unsigned opIndex) {
  SmallVector<AffineMap> opIndexingMaps =
      batchVariantMatmulOp.getIndexingMapsArray();
  SmallVector<AffineMap> defaultIndexingMaps =
      batchVariantMatmulOp.getDefaultIndexingMaps(
          batchVariantMatmulOp->getContext());

  if (opIndexingMaps.size() != 3)
    return batchVariantMatmulOp->emitOpError()
           << "Indexing_map attribute must have 3 affine maps.";

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];

  if (failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
                             defaultIndexingMap, opIndex == 0)))
    return failure();
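// Illustrative (assumed asm form): a batch-dim broadcast on the LHS of a
// linalg.batch_matmul would reach verifyInputMaps with maps such as
//   indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // LHS, batch broadcast
//                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, // RHS
//                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>] // out
// where the two-result LHS map is accepted via isValidLhsRhsBroadcastMap.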
  if (m == 2 && r == 3)
    return WinogradConv2DFmr::F_2_3;
  if (m == 4 && r == 3)
    return WinogradConv2DFmr::F_4_3;
  if (m == 2 && r == 5)
    return WinogradConv2DFmr::F_2_5;
  return std::nullopt;

  case WinogradConv2DFmr::F_2_3:
  case WinogradConv2DFmr::F_4_3:
  case WinogradConv2DFmr::F_2_5:
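// Note (standard Winograd F(m, r) arithmetic, not specific to this file): the
// transformed tile size is alpha = m + r - 1, so F_2_3 uses 4x4 tiles while
// F_4_3 and F_2_5 both use 6x6 tiles.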
static FailureOr<SmallVector<SmallVector<int64_t>>>
getAffineResultPositions(ArrayAttr maps) {
  SmallVector<SmallVector<int64_t>> positions;
  for (auto map : maps) {
    AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
    SmallVector<int64_t> pos;
    for (auto result : attr.getAffineMap().getResults()) {
      auto dim = dyn_cast<AffineDimExpr>(result);
      pos.push_back(dim.getPosition());
    }
    positions.push_back(pos);
  }

  return indexingMaps;
bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (maps.size() != 3)
    return false;

          utils::IteratorType::parallel,
          utils::IteratorType::reduction};

unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

bool MatmulOp::hasUserDefinedMaps() {
  return defaultMaps != explicitMaps;
}

    emitError() << "MatmulOp regionBuilder expects 3 args, got "
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);

  TypeFn castVal = TypeFn::cast_signed;
  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");

  ArrayAttr arrayAttr;
  if (llvm::any_of(arrayAttr, [](auto elt) {
        return !dyn_cast<AffineMapAttr>(elt);
      }))
    return parser.emitError(parser.getCurrentLocation())
           << "element of indexing_maps array is not an affine_map";

  if (failed(indexingMapsAttr))
    return failure();

  if (*indexingMapsAttr == nullptr) {
    auto indexingMapAttrs = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });

  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
  SmallVector<Attribute> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};

  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {

void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
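// Illustrative IR (assumed asm form): with the default maps the attribute is
// elided from the printed form; a transposed-A matmul keeps its maps, e.g.
//   linalg.matmul indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
//                                  affine_map<(d0, d1, d2) -> (d2, d1)>,
//                                  affine_map<(d0, d1, d2) -> (d0, d1)>]
//       ins(%a, %b : tensor<5x3xf32>, tensor<5x7xf32>)
//       outs(%c : tensor<3x7xf32>) -> tensor<3x7xf32>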
SmallVector<AffineMap>
MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
  return {mapLHS, mapRHS, mapOut};
}

  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (maps.size() != 3)
    return false;

  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));

  build(builder, state, inputs, outputs, attributes);
  auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));

  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
  auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));

  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
  auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  return dyn_cast_or_null<linalg::MatmulOp>(op) &&
         MatmulTransposeAOp::isDefaultIndexingMaps(
             op->getAttr("indexing_maps"));
SmallVector<AffineMap>
MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
  return {mapLHS, mapRHS, mapOut};
}

  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (maps.size() != 3)
    return false;

  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));

  build(builder, state, inputs, outputs, attributes);
  auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));

  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
  auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));

  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
  auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  return dyn_cast_or_null<linalg::MatmulOp>(op) &&
         MatmulTransposeBOp::isDefaultIndexingMaps(
             op->getAttr("indexing_maps"));
SmallVector<AffineMap>
BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
  return {mapLHS, mapRHS, mapOut};
}

  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (maps.size() != 3)
    return false;

  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
                BatchMatmulOp::getRegionBuilder(),
                getDefaultIndexingMaps(builder));

  build(builder, state, inputs, outputs, attributes);
  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                BatchMatmulOp::getRegionBuilder(),
                getDefaultIndexingMaps(builder));

  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                BatchMatmulOp::getRegionBuilder(),
                getDefaultIndexingMaps(builder));

  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
         BatchMatmulTransposeAOp::isDefaultIndexingMaps(
             op->getAttr("indexing_maps"));
SmallVector<AffineMap>
BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
  return {mapLHS, mapRHS, mapOut};
}

  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (maps.size() != 3)
    return false;

  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
                BatchMatmulOp::getRegionBuilder(),
                getDefaultIndexingMaps(builder));

  build(builder, state, inputs, outputs, attributes);
  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                BatchMatmulOp::getRegionBuilder(),
                getDefaultIndexingMaps(builder));

  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
                BatchMatmulOp::getRegionBuilder(),
                getDefaultIndexingMaps(builder));

  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
  assert(res && "builder didn't return the right type");

  return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
         BatchMatmulTransposeBOp::isDefaultIndexingMaps(
             op->getAttr("indexing_maps"));
  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();

  for (auto result : outAffineMap.getResults()) {
    auto dimExpr = dyn_cast<AffineDimExpr>(result);
    assert(dimExpr && "affine_map is a projected permutation");
    dimsInOutput[dimExpr.getPosition()] = true;
  }

  for (auto dimOccursInOutput : dimsInOutput)
    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
                                              : utils::IteratorType::reduction);

  return iteratorTypes;

unsigned ContractOp::getNumRegionArgs() { return 3; }

    emitError() << "ContractOp regionBuilder expects 3 args, got "
  assert(block.getNumArguments() == 3 &&
         "ContractOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);

  TypeFn castSignedness = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castSignedness = attr.getValue();
  }

  Value lhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
  Value rhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
  Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
                                                rhsAtOutType, emitError);
  if (!productAtOutType)
    return;
  helper.yieldOutputs({result});

  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
    return parser.emitError(parser.getCurrentLocation(),
                            "expected 'indexing_maps' attribute");
  result.addAttribute("indexing_maps", *indexingMapsAttr);

  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         {"indexing_maps", "operandSegmentSizes"});
  int iterationSpaceDims = -1;

  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
      return emitError("provided affine_map is not a projected permutation");

    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
        return emitError("ranks of shaped operand and results of corresponding "
                         "affine_map differ");
        return emitError("affine_map specifies shaped access while operand has "

    if (iterationSpaceDims == -1) {
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError("iteration spaces of provided affine_maps differ");
    }

      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");
      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;

  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure();

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];

    hasContractingDim |= inOccCount == 2 && outOccCount == 0;

    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";

    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";

  if (!hasContractingDim)
    return emitError("'indexing_maps' do not specify a contracting dimension");
void ContractOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;

SmallVector<AffineMap>
BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2, d3;
  bindDims(context, d0, d1, d2, d3);
  SmallVector<AffineMap> indexingMaps;
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
  return indexingMaps;
}
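// In asm form the defaults above read:
//   affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>  // LHS: batch x M x K
//   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>  // RHS: batch x K x N
//   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>  // out: batch x M x N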
bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (maps.size() != 3)
    return false;

      utils::IteratorType::parallel, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};

unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }

std::string BatchMatmulOp::getLibraryCallName() {

bool BatchMatmulOp::hasUserDefinedMaps() {
  return defaultMaps != explicitMaps;
}

bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };

void BatchMatmulOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
    emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
  assert(block.getNumArguments() == 3 &&
         "BatchMatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);

  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }

  return parseNamedStructuredOp(parser, result,
                                BatchMatmulOp::getNumRegionArgs(),
                                BatchMatmulOp::getRegionBuilder());

  SmallVector<Attribute> indexingMaps = llvm::map_to_vector(
      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};

  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {

LogicalResult BatchMatmulOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {

void BatchMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
struct ArityGroupAndKind {
  ElementwiseArityGroup arityGroup;
  union Kind {
    UnaryFn unaryFn;
    BinaryFn binaryFn;
    TernaryFn ternaryFn;
  } kind;
};

unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {

  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
  constexpr int lastBinary =
      static_cast<int>(ElementwiseCaseLimits::LastBinary);
  constexpr int lastTernary =
      static_cast<int>(ElementwiseCaseLimits::LastTernary);

  int val = static_cast<int>(kind);
  ArityGroupAndKind result;

  if (val < lastUnary) {
    result.arityGroup = ElementwiseArityGroup::Unary;
    result.kind.unaryFn = static_cast<UnaryFn>(val);
    return result;
  }
  if (val < lastBinary) {
    result.arityGroup = ElementwiseArityGroup::Binary;
    result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
    return result;
  }
  if (val >= lastTernary) {
    llvm_unreachable("unhandled ElementwiseFn");
  }
  result.arityGroup = ElementwiseArityGroup::Ternary;
  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
  return result;
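// Worked example (assumed enum layout): ElementwiseKind values are laid out
// in consecutive [unary | binary | ternary] bands, so a kind below LastUnary
// decodes as {Unary, UnaryFn(val)}, while the first kind past the unary band
// decodes as {Binary, BinaryFn(val - lastUnary)}, and so on for ternary.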
  auto rank = getResultRank();

SmallVector<AffineMap>
ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
                                      MLIRContext *context) {

  mlir::linalg::ElementwiseKind elemwiseKindVal;

    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
    if (!elemwiseKindAttr)
      return parser.emitError(parser.getCurrentLocation(),
                              "expected ElementwiseKind attribute");
    elemwiseKindVal = elemwiseKindAttr.getValue();
  } else {
    return parser.emitError(parser.getCurrentLocation(),
                            "expected operation 'kind' attribute");
  }

      if (!isa<AffineMapAttr>(mapAttr))
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      indexingMapsAttr.push_back(mapAttr);

  unsigned numRegionArgs =
      getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1;
  if (parseNamedStructuredOp(parser, result, numRegionArgs,
                             ElementwiseOp::getRegionBuilder())) {
    return parser.emitError(parser.getCurrentLocation(),
                            "unable to parse elemwise op");
  }

  if (indexingMapsAttr.empty()) {
    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
    if (!shapedType)
      return parser.emitError(parser.getCurrentLocation(),
                              "return type needs to be shaped type");
    auto numDims = shapedType.getRank();
    indexingMapsAttr = llvm::map_to_vector(
        ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
                                              parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }

  unsigned numDims = getResultRank();
  SmallVector<Attribute> indexingMaps = llvm::map_to_vector(
      ElementwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
void ElementwiseOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  ElementwiseKind elemwiseKind;
  for (auto attr : attrs) {
      auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
      assert(kindAttr && "op kind attribute incorrectly set");
      elemwiseKind = kindAttr.getValue();
  }

  auto kind = groupAndKind.kind;
    emitError() << "Elementwise regionBuilder expects "
                << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "

  assert(block.getNumArguments() == getArityGroupAsUInt(arityGroup) + 1 &&
         "Elementwise regionBuilder number of block args mismatch");

  RegionBuilderHelper helper(b, block);

  if (arityGroup == ElementwiseArityGroup::Unary) {
  } else if (arityGroup == ElementwiseArityGroup::Binary) {
  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
  } else {
    assert(false && "found unhandled category in elemwise");
  }

  yields.push_back(result);
  helper.yieldOutputs(yields);

LogicalResult ElementwiseOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {

void ElementwiseOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
template <typename OpTy, typename>
  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
                                    ? packOrUnPack.getDestType()
                                    : packOrUnPack.getSourceType();
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
      packedType.getShape().take_front(unpackedType.getRank()));
  if (!packOrUnPack.getOuterDimsPerm().empty()) {

  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }

    OpFoldResult tile = std::get<1>(it);
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
      newMixedTileSizes.push_back(tile);
    } else {
      assert(getConstantIntValue(tile).value() == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(

  return newMixedTileSizes;
template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  reifiedReturnShapes[0] =
      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());

template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");

  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");

  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;

template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (ShapedType::isStatic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;

template <typename OpTy>
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");

  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
    return op->emitError("invalid inner_dims_pos vector");
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }

    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();

  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
  if (expectedPackedRank != packedRank) {
    return op->emitError(
               "packed rank != (unpacked rank + num tiling factors), got ")
           << packedRank << " != " << expectedPackedRank;
  }

  RankedTensorType expectedPackedType = PackOp::inferPackedType(
      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            int64_t shape = std::get<0>(it);
            if (Attribute attr =
                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
              int64_t staticTileSize = intAttr.getValue().getSExtValue();
              return shape == staticTileSize;
            }
            return ShapedType::isDynamic(shape);
          })) {
    return op->emitError("mismatch in inner tile sizes specified and shaped of "
                         "tiled dimension in the packed type");
  }
  if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
                                   packedType.getShape()))) {
    return op->emitError("expected ")
           << expectedPackedType << " for the packed domain value, got "
           << packedType;
  }
struct PackOrUnPackTransposeResult {
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};

template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
  }

void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}

void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
  ShapedType inputType = getSourceType();
  int64_t inputRank = inputType.getRank();
  return getDestType().getShape().take_front(inputRank);

  auto packedShape = getDestType().getShape();
  SmallVector<int64_t> res;
  for (auto index : getInnerDimsPos())
    res.push_back(packedShape[index]);
  return res;

      outputShape.take_front(inputShape.size()));
  assert(outerDimsPerm.size() == outputTileSizes.size() &&
         "expected output and outer_dims_perm to have same size");

    if (ShapedType::isDynamic(inputShape[pos]))
      continue;

    if (!constantTile) {
      if (ShapedType::isStatic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }

  auto paddingValue = getPaddingValue();
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }

  for (auto o : ofrs) {
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);

    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
      continue;
    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
      resultShape[tiledDim.value()] = ShapedType::kDynamic;
      continue;
    }
    resultShape[tiledDim.value()] = llvm::divideCeilSigned(
        resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);

  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());

    resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
        builder, loc, ceilDivExpr,
        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

  for (unsigned i = 0; i < resultDims.size(); ++i) {
    if (ShapedType::isStatic(resultTypeShape[i]))
      continue;

RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
  for (auto [index, value] : llvm::enumerate(
           llvm::cast<RankedTensorType>(source.getType()).getShape())) {
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(
          tensor::DimOp::create(b, loc, source, index).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(value));
  }
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    OpFoldResult tileSize = std::get<1>(it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
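// Worked example (hypothetical sizes): for a tensor<128x?xf32> source with
// inner_dims_pos = [0, 1] and inner_tiles = [8, 32], the destination sizes
// become ceilDiv(128, 8) = 16 and ceilDiv(%d, 32), followed by the inner
// tiles themselves, i.e. a tensor.empty of shape 16x?x8x32.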
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return PackOp::create(b, loc, getSource(), transposedDest,
                        metadata.innerDimsPos, metadata.innerTiles,
                        getPaddingValue(), metadata.outerDimsPerm);

template <typename OpTy>
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }

  if (getPaddingValue())
    return Speculation::Speculatable;

  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
    return false;
  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
    return true;

  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
  auto srcType = op.getSourceType();
  if (llvm::any_of(op.getInnerDimsPos(),
                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
    return false;
  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
    return false;
  return !PackOp::requirePaddingValue(
      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
      op.getOuterDimsPerm(), op.getMixedTiles());

  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(packOp.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() != packOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||

    rewriter.replaceOp(packOp, unPackOp.getSource());

    packOp.getPaddingValueMutable().clear();

    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    RankedTensorType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
    }
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));

    if (needUpdateDestType) {
      auto castOp =
          tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOrUnpackOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                    std::is_same<PackOrUnpackOp, UnPackOp>::value,
                "Function meant for pack/unpack");

  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));

  int64_t packedRank = packedTensorType.getRank();

  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}

bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}

  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
    PackOp newOp =
        PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
                       op.getInnerDimsPos(), newMixedTileSizes,
                       op.getPaddingValue(), op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement =
        (newResult.getType() != oldResult.getType())
            ? tensor::CastOp::create(rewriter, op->getLoc(),
                                     oldResult.getType(), newResult)
            : newResult;

void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}

  ShapedType destType = getDestType();
  int64_t destRank = destType.getRank();
  return getSourceType().getShape().take_front(destRank);
  if (!outerDimPermInv.empty())

    res.push_back(outerDims[index]);

  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  build(builder, state, dest.getType(), source, dest,

  auto srcType = llvm::cast<RankedTensorType>(source.getType());
  for (auto i :
       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
    if (srcType.isDynamicDim(i))
      mixedSizes.push_back(
          tensor::DimOp::create(b, loc, source, i).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector<OpFoldResult>(
        mixedSizes, invertPermutationVector(outerDimsPerm));
  }

  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

  auto elemType = srcType.getElementType();
  return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);

UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return UnPackOp::create(b, loc, transposedSource, getDest(),
                          metadata.innerDimsPos, metadata.innerTiles,
                          metadata.outerDimsPerm);
}
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(op.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||

    rewriter.replaceOp(unPackOp, packOp.getSource());

  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(
        unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });

  if (unPackOp->hasOneUse()) {
    auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
      auto newDest = tensor::ExtractSliceOp::create(
          rewriter, unPackOp->getLoc(), unPackOp.getDest(),
          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
          extractSliceUser.getMixedStrides());
      rewriter.modifyOpInPlace(unPackOp, [&]() {
        unPackOp.setDpsInitOperand(0, newDest);
        unPackOp.getResult().setType(newDest.getType());
      });
      rewriter.replaceOp(extractSliceUser, unPackOp);

    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = tensor::CastOp::create(rewriter, loc, newSrcType,
                                      unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = tensor::CastOp::create(rewriter, loc, newDestType,
                                    unPackOp.getDest());
    }
    Value newOp = UnPackOp::create(
        rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
        unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp);
bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
    return false;

  RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
  for (auto [pos, tileSize] :
       llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
    if (unpackedTypeAfterFold.isDynamicDim(pos))
      return false;
    if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
      return false;
    if (ShapedType::isDynamic(tileSize))
      return false;
    int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
                          unpackedTypeAfterFold.getDimSize(pos);
    if (paddingSize >= tileSize)
      return false;
  }
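// Numeric example (hypothetical sizes): with 4 outer tiles of size 8 the
// unpack covers 32 elements along that dim; folding an extract_slice down to
// size 30 leaves paddingSize = 32 - 30 = 2 < 8, so it folds, while slicing to
// 20 would give paddingSize = 12 >= 8 and is rejected.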
bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}

  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;

    Value sourceTensor = newOperands[0];
    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
        rewriter, sourceTensor.getType(), op.getMixedTiles());

    UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
                                      newOperands[1], op.getInnerDimsPos(),
                                      newMixedTileSizes, op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement =
        (newResult.getType() != oldResult.getType())
            ? tensor::CastOp::create(rewriter, op->getLoc(),
                                     oldResult.getType(), newResult)
            : newResult;
      utils::IteratorType::reduction, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};

SmallVector<AffineMap>
BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2, d3;
  bindDims(context, d0, d1, d2, d3);
  SmallVector<AffineMap> indexingMaps;
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
  return indexingMaps;
}
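// In asm form the defaults read (the out map drops the reduced batch and K
// dims, consistent with the reduction/parallel/parallel/reduction iterators):
//   affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>  // LHS: batch x M x K
//   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>  // RHS: batch x K x N
//   affine_map<(d0, d1, d2, d3) -> (d1, d2)>      // out: M x N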
bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
  if (maps.size() != 3)
    return false;

unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }

std::string BatchReduceMatmulOp::getLibraryCallName() {

bool BatchReduceMatmulOp::hasUserDefinedMaps() {
  return defaultMaps != explicitMaps;
}

bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                                    bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };

void BatchReduceMatmulOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
    emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
  assert(block.getNumArguments() == 3 &&
         "BatchReduceMatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);

  Value castValA =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
  Value castValB =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }

  return parseNamedStructuredOp(parser, result,
                                BatchReduceMatmulOp::getNumRegionArgs(),
                                BatchReduceMatmulOp::getRegionBuilder());

  SmallVector<Attribute> indexingMaps = llvm::map_to_vector(
      BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }

  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};

  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {

LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
                                        SmallVectorImpl<OpFoldResult> &) {

void BatchReduceMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
void LinalgDialect::getCanonicalizationPatterns(

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
union mlir::linalg::@1244::ArityGroupAndKind::Kind kind
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
SmallVector< int64_t > outerDimsPerm
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
SmallVector< OpFoldResult > innerTiles
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
ElementwiseArityGroup arityGroup
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
SmallVector< int64_t > innerDimsPos
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static bool canUseShortForm(Block *body, bool initFirst=false)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator range over the underlying Values.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
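A minimal sketch tying these generic Operation accessors together; inspect is a hypothetical helper:

static void inspect(Operation *op) {
  llvm::outs() << op->getName() << " at " << op->getLoc() << ": "
               << op->getNumOperands() << " operands, "
               << op->getNumResults() << " results\n";
  for (OpResult result : op->getResults())
    llvm::outs() << "  result #" << result.getResultNumber() << " : "
                 << result.getType() << "\n";
}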
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification.
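A sketch of this rewriter API in use; the single-result check and the "visited" attribute are placeholders, not logic from this file:

static LogicalResult rewriteExample(Operation *op, PatternRewriter &rewriter) {
  if (op->getNumResults() != 1)
    return rewriter.notifyMatchFailure(op, "expected a single result");
  // Wrap an in-place mutation so start/finalizeOpModification fire correctly.
  rewriter.modifyOpInPlace(op, [&] {
    op->setAttr("visited", rewriter.getUnitAttr());
  });
  return success();
}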
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
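A small sketch of these Value accessors; describe is a hypothetical helper:

static void describe(Value v) {
  llvm::outs() << "type: " << v.getType();
  if (Operation *def = v.getDefiningOp())
    llvm::outs() << ", defined by " << def->getName();
  else
    llvm::outs() << ", a block argument";   // values without a defining op
  if (v.hasOneUse())
    llvm::outs() << " (single use)";
  llvm::outs() << "\n";
}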
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
Specialization of linalg.batch_matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
Specialization of linalg.batch_matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
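Assuming builder, loc, and tensor-typed values lhsT (A already transposed), rhs, and init are in scope, the specialized create() above builds the variant directly; the builder attaches the transposed indexing maps:

auto matmul = linalg::MatmulTransposeAOp::create(
    builder, loc, /*inputs=*/ValueRange{lhsT, rhs},
    /*outputs=*/ValueRange{init});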
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOps supplying those operands.
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
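A sketch of composing a folded affine apply for lhs + rhs, where builder, loc, and the OpFoldResults lhs and rhs are assumed to be in scope; bindDims is listed further below:

AffineExpr d0, d1;
bindDims(builder.getContext(), d0, d1);
OpFoldResult sum = affine::makeComposedFoldedAffineApply(
    builder, loc, AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1),
    {lhs, rhs});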
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with the inferred static shape.
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Values to kDynamic, even if they are arith.constant values.
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level.
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx to startIdx + num.
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Kind
An enumeration of the kinds of predicates.
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e. one that can be folded into the consuming op), compute the updated operands and populate newResTy with the new result types.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
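A brief sketch combining getMixedSizes with getConstantIntValue (listed below); builder, loc, and the tensor value are assumed to be in scope:

SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(builder, loc, value);
for (OpFoldResult ofr : sizes)
  if (std::optional<int64_t> cst = getConstantIntValue(ofr))
    llvm::outs() << "static size: " << *cst << "\n";  // dynamic sizes stay Values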
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult for a single OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and sinking them, in their order of occurrence in forOps, under each of the targets.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
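A sketch tying the permutation helpers together; the vectors are arbitrary example data:

SmallVector<int64_t> perm = {2, 0, 1};
assert(isPermutationVector(perm) && "expected a permutation");
SmallVector<int64_t> inv = invertPermutationVector(perm);                 // {1, 2, 0}
SmallVector<int64_t> shape = applyPermutation<int64_t>({4, 8, 16}, perm); // {16, 4, 8}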
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broadcast(input)) -> broadcast(transpose(input)).
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
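A skeleton of an OpRewritePattern as described above; the identity-transpose folding body is a toy stand-in (and assumes tensor semantics), not a pattern from this file:

struct FoldIdentityTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    if (!isIdentityPermutation(op.getPermutation()))
      return rewriter.notifyMatchFailure(op, "not an identity permutation");
    rewriter.replaceOp(op, op.getInput());
    return success();
  }
};
// Registration: patterns.add<FoldIdentityTranspose>(ctx);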
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
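Generic construction through OperationState, mirroring the build helpers in this file; the op name, operands, and attribute are placeholders, and creating an unregistered op name assumes the context permits it:

OperationState state(loc, "mydialect.myop");
state.addOperands({lhs, rhs});                     // assumed Values in scope
state.addTypes(lhs.getType());
state.addAttribute("flag", builder.getUnitAttr());
Region &bodyRegion = *state.addRegion();           // empty region attached to the op
(void)bodyRegion;
Operation *op = builder.create(state);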
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp if the tensor.cast has a source that is more static than the consuming op.
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp if the tensor.cast has a source that is more static than the consuming op.
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override