41 #include "llvm/ADT/DenseMap.h"
42 #include "llvm/ADT/STLExtras.h"
43 #include "llvm/ADT/SetOperations.h"
44 #include "llvm/ADT/SmallSet.h"
45 #include "llvm/ADT/SmallVector.h"
46 #include "llvm/ADT/StringSet.h"
47 #include "llvm/ADT/TypeSwitch.h"
48 #include "llvm/Support/FormatVariadic.h"
49 #include "llvm/Support/LogicalResult.h"
50 #include "llvm/Support/MathExtras.h"
51 #include "llvm/Support/raw_ostream.h"
61 auto type = cast<ShapedType>(v.
getType());
62 if (!type.isDynamicDim(dim))
67 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
68 return builder.create<tensor::DimOp>(loc, v, dim);
70 .Case<MemRefType>([&](MemRefType t) ->
Value {
71 return builder.create<memref::DimOp>(loc, v, dim);
82 .Case<RankedTensorType>([&](RankedTensorType t) ->
Operation * {
83 return b.
create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
86 .Case<MemRefType>([&](MemRefType type) ->
Operation * {
87 return b.
create<memref::SubViewOp>(loc, source, offsets, sizes,
99 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
101 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
103 llvm_unreachable(
"Expected MemRefType or TensorType");
108 auto shapedType = llvm::cast<ShapedType>(source.
getType());
109 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
132 for (
auto containers : {inputTypes, outputTypes}) {
133 for (
auto t : containers) {
145 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
149 regionBuilder(b, *body, attrs);
161 std::optional<TypeRange> resultTensorTypes,
168 if (!resultTensorTypes)
169 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
170 llvm::IsaPred<RankedTensorType>);
172 state.addOperands(inputs);
173 state.addOperands(outputs);
174 state.addTypes(derivedResultTypes);
176 state.addAttributes(attributes);
178 "operandSegmentSizes",
180 static_cast<int32_t>(outputs.size())}));
183 Region ®ion = *state.addRegion();
185 state.attributes.getAttrs(), regionBuilder);
189 std::optional<TypeRange> resultTensorTypes,
196 indexingMapsAttrVal = llvm::map_to_vector(
197 MatmulOp::getDefaultIndexingMaps(b.
getContext()),
199 state.addAttribute(
"indexing_maps", b.
getArrayAttr(indexingMapsAttrVal));
201 attributes, regionBuilder);
205 std::optional<TypeRange> resultTensorTypes,
212 indexingMapsAttrVal =
216 state.addAttribute(
"indexing_maps", b.
getArrayAttr(indexingMapsAttrVal));
218 attributes, regionBuilder);
227 bool addOperandSegmentSizes =
true) {
228 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
257 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
259 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
263 if (addOperandSegmentSizes) {
272 attrs.
append(
"operandSegmentSizes",
274 {static_cast<int32_t>(inputsOperands.size()),
275 static_cast<int32_t>(outputsOperands.size())}));
280 {static_cast<int32_t>(inputsOperands.size()),
281 static_cast<int32_t>(outputsOperands.size())}));
285 std::optional<RegisteredOperationName> info =
288 if (failed(info->verifyInherentAttrs(result.
attributes, [&]() {
289 return parser.emitError(attrsLoc)
290 <<
"'" << result.name.getStringRef() <<
"' op ";
301 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
302 if (!outputs.empty())
303 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
314 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
317 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
318 "region expects {0} args, got {1}",
319 numRegionArgs, inputTypes.size() + outputTypes.size()));
338 unsigned numRegionArgs,
354 result.
addTypes(outputTensorsTypes);
356 std::unique_ptr<Region> region = std::make_unique<Region>();
368 if (resultTypes.empty())
413 class RegionBuilderHelper {
416 : builder(builder), block(block) {}
419 Value buildUnaryFn(UnaryFn unaryFn,
Value arg) {
420 if (!isFloatingPoint(arg))
421 llvm_unreachable(
"unsupported non numeric type");
423 builder.setInsertionPointToEnd(&block);
426 return builder.create<math::ExpOp>(arg.
getLoc(), arg);
428 return builder.create<math::LogOp>(arg.
getLoc(), arg);
430 return builder.create<math::AbsFOp>(arg.
getLoc(), arg);
432 return builder.create<math::CeilOp>(arg.
getLoc(), arg);
434 return builder.create<math::FloorOp>(arg.
getLoc(), arg);
436 return builder.create<arith::NegFOp>(arg.
getLoc(), arg);
437 case UnaryFn::reciprocal: {
439 auto one = builder.create<arith::ConstantOp>(arg.
getLoc(),
440 ::cast<TypedAttr>(oneAttr));
441 return builder.create<arith::DivFOp>(arg.
getLoc(), one, arg);
444 return builder.create<math::RoundOp>(arg.
getLoc(), arg);
446 return builder.create<math::SqrtOp>(arg.
getLoc(), arg);
448 return builder.create<math::RsqrtOp>(arg.
getLoc(), arg);
449 case UnaryFn::square:
450 return builder.create<arith::MulFOp>(arg.
getLoc(), arg, arg);
452 return builder.create<math::TanhOp>(arg.
getLoc(), arg);
454 return builder.create<math::ErfOp>(arg.
getLoc(), arg);
456 llvm_unreachable(
"unsupported unary function");
461 bool allComplex = isComplex(arg0) && isComplex(arg1);
462 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
463 bool allInteger = isInteger(arg0) && isInteger(arg1);
466 if (!allComplex && !allFloatingPoint && !allInteger)
467 llvm_unreachable(
"unsupported non numeric type");
469 builder.setInsertionPointToEnd(&block);
473 return builder.create<complex::AddOp>(arg0.
getLoc(), arg0, arg1);
474 if (allFloatingPoint)
475 return builder.create<arith::AddFOp>(arg0.
getLoc(), arg0, arg1);
477 return builder.create<arith::OrIOp>(arg0.
getLoc(), arg0, arg1);
478 return builder.create<arith::AddIOp>(arg0.
getLoc(), arg0, arg1);
481 return builder.create<complex::SubOp>(arg0.
getLoc(), arg0, arg1);
482 if (allFloatingPoint)
483 return builder.create<arith::SubFOp>(arg0.
getLoc(), arg0, arg1);
485 llvm_unreachable(
"unsupported operation: sub with bools");
486 return builder.create<arith::SubIOp>(arg0.
getLoc(), arg0, arg1);
489 return builder.create<complex::MulOp>(arg0.
getLoc(), arg0, arg1);
490 if (allFloatingPoint)
491 return builder.create<arith::MulFOp>(arg0.
getLoc(), arg0, arg1);
493 return builder.create<arith::AndIOp>(arg0.
getLoc(), arg0, arg1);
494 return builder.create<arith::MulIOp>(arg0.
getLoc(), arg0, arg1);
497 return builder.create<complex::DivOp>(arg0.
getLoc(), arg0, arg1);
498 if (allFloatingPoint)
499 return builder.create<arith::DivFOp>(arg0.
getLoc(), arg0, arg1);
501 llvm_unreachable(
"unsupported operation: div with bools");
502 return builder.create<arith::DivSIOp>(arg0.
getLoc(), arg0, arg1);
503 case BinaryFn::div_unsigned:
504 if (!allInteger || allBool)
505 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
506 return builder.create<arith::DivUIOp>(arg0.
getLoc(), arg0, arg1);
507 case BinaryFn::max_signed:
509 if (allFloatingPoint)
510 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
511 return builder.create<arith::MaxSIOp>(arg0.
getLoc(), arg0, arg1);
512 case BinaryFn::min_signed:
514 if (allFloatingPoint)
515 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
516 return builder.create<arith::MinSIOp>(arg0.
getLoc(), arg0, arg1);
517 case BinaryFn::max_unsigned:
519 if (allFloatingPoint)
520 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
521 return builder.create<arith::MaxUIOp>(arg0.
getLoc(), arg0, arg1);
522 case BinaryFn::min_unsigned:
524 if (allFloatingPoint)
525 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
526 return builder.create<arith::MinUIOp>(arg0.
getLoc(), arg0, arg1);
528 assert(allFloatingPoint);
529 return builder.create<math::PowFOp>(arg0.
getLoc(), arg0, arg1);
531 llvm_unreachable(
"unsupported binary function");
539 bool tailFloatingPoint =
540 isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
541 bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
543 builder.setInsertionPointToEnd(&block);
545 case TernaryFn::select:
546 if (!headBool && !(tailFloatingPoint || tailInteger))
547 llvm_unreachable(
"unsupported non numeric type");
548 return builder.create<arith::SelectOp>(arg0.
getLoc(), arg0, arg1, arg2);
550 llvm_unreachable(
"unsupported ternary function");
556 case TypeFn::cast_signed:
557 return cast(toType, operand,
false);
558 case TypeFn::cast_unsigned:
559 return cast(toType, operand,
true);
561 llvm_unreachable(
"unsupported type conversion function");
566 builder.setInsertionPointToEnd(&block);
567 Location loc = builder.getUnknownLoc();
568 builder.create<YieldOp>(loc, values);
571 Value constant(
const std::string &value) {
573 builder.setInsertionPointToEnd(&block);
574 Location loc = builder.getUnknownLoc();
576 return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
579 Value index(int64_t dim) {
581 builder.setInsertionPointToEnd(&block);
582 return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
585 Type getIntegerType(
unsigned width) {
599 builder.setInsertionPointToEnd(&block);
600 auto loc = operand.
getLoc();
604 bool isComplex(
Value value) {
605 return llvm::isa<ComplexType>(value.
getType());
607 bool isFloatingPoint(
Value value) {
608 return llvm::isa<FloatType>(value.
getType());
610 bool isInteger(
Value value) {
611 return llvm::isa<IntegerType>(value.
getType());
628 LogicalResult matchAndRewrite(CopyOp copyOp,
630 if (copyOp.getInputs() != copyOp.getOutputs())
632 if (copyOp.hasPureBufferSemantics())
635 rewriter.
replaceOp(copyOp, copyOp.getInputs());
645 results.
add<EraseSelfCopy>(context);
658 template <
typename TensorReshapeOp>
661 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
663 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
668 TensorReshapeOp newInit;
669 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
671 newInit = rewriter.
create<TensorReshapeOp>(
672 loc, reshapeOp.getResultType(), oldFill.output(),
673 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
674 reshapeOp.getStaticOutputShape());
676 newInit = rewriter.
create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
678 reshapeOp.getReassociation());
691 LogicalResult matchAndRewrite(tensor::PadOp padOp,
693 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
699 Value padValue = padOp.getConstantPaddingValue();
700 if (!padValue || fillOp.value() != padValue)
706 padOp,
"failed to reify tensor.pad op result shape");
708 auto emptyTensor = rewriter.
create<tensor::EmptyOp>(
709 padOp.getLoc(), reifiedShape.front(),
710 padOp.getResultType().getElementType());
716 if (replacement.getType() != padOp.getResultType()) {
717 replacement = rewriter.
create<tensor::CastOp>(
718 fillOp.getLoc(), padOp.getResultType(), replacement);
728 struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
731 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
733 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
737 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
742 Value firstDest = insertOp.getDest();
743 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
744 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
749 bool disjoint =
false;
750 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
753 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
754 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
755 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
759 int64_t prevStart = prevOp.getStaticOffset(i);
760 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
761 prevOp.getStaticStride(i);
762 int64_t nextStart = insertOp.getStaticOffset(i);
763 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
764 insertOp.getStaticStride(i);
765 if (prevEnd < nextStart || nextEnd < prevStart) {
773 firstDest = prevOp.getDest();
784 Value padValue = srcPadOp.getConstantPaddingValue();
785 if (!padValue || dstFillOp.value() != padValue)
801 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
803 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
806 RankedTensorType srcPadType = srcPadOp.getSourceType();
808 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
809 if (srcPadType.isDynamicDim(i)) {
811 rewriter.
create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
814 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
819 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
820 newSizes, insertOp.getMixedStrides());
826 struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
830 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
834 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
839 Value extractedScalar = fillOp.getInputs()[0];
842 rewriter.
replaceOp(extractOp, extractedScalar);
850 static FailureOr<FillOp> foldFillPackIntoFillOp(
RewriterBase &rewriter,
851 linalg::PackOp packOp) {
852 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
856 if (
auto paddingValue = packOp.getPaddingValue())
860 Value packOpDest = packOp.getDest();
864 return rewriter.
create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
874 LogicalResult matchAndRewrite(linalg::PackOp packOp,
876 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
879 rewriter.
replaceOp(packOp, fillOp.value().result());
888 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
890 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
893 copyOp.getOutputs());
896 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
898 fillOp.getOutputs());
909 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
911 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
913 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
914 transposeOp.getDpsInitOperand(0)->get());
926 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
928 auto concatOperands = concatOp.getInputs();
929 if (concatOperands.empty()) {
933 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
942 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
944 auto isDefinedByCompatibleFillOp = [&](
Value v) ->
bool {
945 auto fillOp = v.getDefiningOp<linalg::FillOp>();
952 if (fillVal != firstFillVal)
955 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
958 if (!llvm::all_of(concatOperands.drop_front(),
959 isDefinedByCompatibleFillOp)) {
961 concatOp,
"not all operands are defined by a compatible fill op");
964 Value outsConcat = rewriter.
create<tensor::ConcatOp>(
965 concatOp.getLoc(), concatOp.getDim(), allOuts);
967 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
976 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
977 FoldFillWithPack, FoldFillWithPad,
978 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
979 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
980 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
993 for (
ValueRange container : {inputs, outputs}) {
994 for (
Value v : container) {
995 Type t = v.getType();
996 blockArgTypes.push_back(
998 blockArgLocs.push_back(v.getLoc());
1004 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
1008 void GenericOp::getAsmBlockArgumentNames(
Region ®ion,
1010 for (
Value v : getRegionInputArgs())
1012 for (
Value v : getRegionOutputArgs())
1013 setNameFn(v,
"out");
1016 void GenericOp::build(
1019 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1022 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1023 iteratorTypes, doc, libraryCall);
1027 inputs, outputs, bodyBuild);
1030 void GenericOp::build(
1034 StringRef libraryCall,
1037 build(builder, result, resultTensorTypes, inputs, outputs,
1042 return IteratorTypeAttr::get(builder.getContext(), iter);
1045 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1046 bodyBuild, attributes);
1049 void GenericOp::build(
1053 StringRef libraryCall,
1056 build(builder, result,
TypeRange{}, inputs, outputs, indexingMaps,
1057 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1060 void GenericOp::build(
1066 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1068 "", bodyBuild, attributes);
1071 void GenericOp::build(
1077 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1080 "", bodyBuild, attributes);
1087 auto genericAttrNames = linalgTraitAttrNames();
1090 genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
1092 for (
auto attr : (*this)->getAttrs()) {
1093 if (attr.getName() == getIteratorTypesAttrName()) {
1094 auto iteratorTypes =
1095 llvm::cast<ArrayAttr>(attr.getValue())
1096 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1102 llvm::to_vector(llvm::map_range(
1103 iteratorTypes, [&](utils::IteratorType t) ->
Attribute {
1107 genericAttrs.emplace_back(
1108 getIteratorTypesAttrName(),
1110 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1111 genericAttrs.push_back(attr);
1114 if (!genericAttrs.empty()) {
1116 p << genericDictAttr;
1122 genericAttrNames.push_back(
"operandSegmentSizes");
1123 genericAttrNamesSet.insert(genericAttrNames.back());
1125 bool hasExtraAttrs =
false;
1127 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1130 if (hasExtraAttrs) {
1137 if (!getRegion().empty()) {
1147 DictionaryAttr dictAttr;
1156 dictAttr.getValue().end());
1162 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1164 if (!iteratorTypes) {
1165 return parser.
emitError(attributeLocation)
1166 <<
"expected " << getIteratorTypesAttrName(result.
name)
1167 <<
" array attribute";
1172 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1173 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1174 if (!maybeIteratorType.has_value())
1176 <<
"unexpected iterator_type (" << s <<
")";
1178 iteratorTypeAttrs.push_back(
1195 std::unique_ptr<Region> region = std::make_unique<Region>();
1207 result.
addTypes(outputTensorsTypes);
1215 LinalgOp linalgOp) {
1216 for (
auto [index, operand] :
llvm::enumerate(linalgOp.getDpsInputs())) {
1217 if (!llvm::isa<MemRefType>(operand.
getType()))
1219 effects.emplace_back(
1224 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1225 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1227 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1238 void GenericOp::getEffects(
1248 if (!linalgOp.hasPureTensorSemantics())
1267 template <
typename OpTy>
1271 LogicalResult matchAndRewrite(OpTy linalgOp,
1274 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1279 Block &body = linalgOp->getRegion(0).
front();
1280 if (!llvm::hasSingleElement(body))
1282 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1287 if (linalgOp.hasPureBufferSemantics()) {
1288 if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
1289 linalgOp.getDpsInputOperand(0)->get() ==
1290 linalgOp.getDpsInitOperand(0)->get()) {
1298 if (!linalgOp.hasPureTensorSemantics())
1305 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1306 if (!yieldArg || yieldArg.getOwner() != &body)
1308 unsigned argumentNumber = yieldArg.getArgNumber();
1309 Value returnedArg = linalgOp->getOperand(argumentNumber);
1310 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1314 if (returnType != resultType) {
1319 returnedArg = rewriter.
create<sparse_tensor::ConvertOp>(
1320 linalgOp.getLoc(), resultType, returnedArg);
1322 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1325 returnedArg = rewriter.
create<tensor::CastOp>(
1326 linalgOp.getLoc(), resultType, returnedArg);
1329 returnedArgs.push_back(returnedArg);
1332 if (returnedArgs.size() != linalgOp->getNumResults())
1334 rewriter.
replaceOp(linalgOp, returnedArgs);
1343 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1365 for (
Type outputType : outputTypes) {
1366 if (llvm::isa<RankedTensorType>(outputType))
1371 if (parseAttrsFn && failed(parseAttrsFn(parser, result.
attributes)))
1380 void MapOp::getAsmBlockArgumentNames(
Region ®ion,
1382 for (
Value v : getRegionInputArgs())
1387 if (!getResults().empty())
1388 setNameFn(getResults().front(),
"mapped");
1395 build(builder, result,
TypeRange{}, inputs, init);
1400 if (llvm::isa<RankedTensorType>(initType))
1405 inputs, {}, bodyBuild);
1412 bool initFirst =
false) {
1418 for (
auto &operand : operands) {
1420 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1427 payloadOpOperands.push_back(block.
getArguments().back());
1428 for (
const auto &arg : block.
getArguments().drop_back())
1429 payloadOpOperands.push_back(arg);
1438 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1445 std::optional<OperationName> payloadOpName;
1449 if (failed(operationName))
1453 payloadOpName = operationName.value();
1461 if (payloadOpName.has_value()) {
1499 for (
const auto &[operand, bbArg] :
1501 if (bbArg != operand)
1505 for (
const auto &[operand, bbArg] :
1507 if (bbArg != operand)
1516 std::string attrToElide;
1518 for (
const auto &attr : payloadOp->
getAttrs()) {
1520 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1521 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1522 attrToElide = attr.getName().str();
1523 elidedAttrs.push_back(attrToElide);
1532 Block *mapper = getBody();
1547 [&](
auto arg) { p.printRegionArgument(arg); });
1556 auto *bodyBlock = getBody();
1557 auto blockArgs = bodyBlock->getArguments();
1560 if (getInputs().size() != blockArgs.size())
1561 return emitOpError() <<
"expects number of operands to match the arity of "
1563 << getInputs().size() <<
" and " << blockArgs.size();
1566 for (
const auto &[bbArgType, inputArg] :
1567 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1568 auto inputElemType =
1569 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1570 if (bbArgType != inputElemType) {
1571 return emitOpError() <<
"expected element type of input " << inputElemType
1572 <<
" to match bbArg type " << bbArgType;
1577 auto outputShape = getInit().getType().getShape();
1579 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1580 if (inputElemShape != outputShape) {
1581 return emitOpError() <<
"expected shape of input (" << inputElemShape
1582 <<
") to match shape of output (" << outputShape
1591 int64_t rank = getInit().getType().getRank();
1595 ArrayAttr MapOp::getIndexingMaps() {
1597 int64_t rank = getInit().getType().getRank();
1598 int64_t numIndexingMaps = getOperands().size();
1603 void MapOp::getEffects(
1617 void ReduceOp::getAsmBlockArgumentNames(
Region ®ion,
1619 for (
Value v : getRegionInputArgs())
1621 for (
Value v : getRegionOutputArgs())
1622 setNameFn(v,
"init");
1625 void ReduceOp::getAsmResultNames(
1627 if (!getResults().empty())
1628 setNameFn(getResults().front(),
"reduced");
1631 void ReduceOp::build(
1636 build(builder, result,
TypeRange{}, inputs, inits, dimensions);
1640 for (
Value init : inits) {
1642 if (llvm::isa<RankedTensorType>(initType))
1648 inputs, inits, bodyBuild);
1653 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1655 utils::IteratorType::parallel);
1656 for (int64_t reductionDim : getDimensions())
1657 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1658 return iteratorTypes;
1661 ArrayAttr ReduceOp::getIndexingMaps() {
1663 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1670 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1671 affineMaps.push_back(resultMap);
1675 void ReduceOp::getEffects(
1687 StringRef attributeName) {
1696 std::optional<OperationName> payloadOpName;
1700 if (failed(operationName))
1704 payloadOpName = operationName.value();
1715 if (payloadOpName.has_value()) {
1735 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1739 Block *mapper = getBody();
1754 [&](
auto arg) { p.printRegionArgument(arg); });
1765 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1766 if (llvm::cast<ShapedType>(getInputs()[i].
getType()).getShape() !=
1767 llvm::cast<ShapedType>(getInputs()[0].
getType()).getShape()) {
1768 return emitOpError() <<
"expects all inputs to have the same shapes. "
1769 "Shape at input-index "
1771 <<
" is not equal to the shape at input-index 0.";
1774 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1775 if (llvm::cast<ShapedType>(getInits()[i].
getType()).getShape() !=
1776 llvm::cast<ShapedType>(getInits()[0].
getType()).getShape()) {
1777 return emitOpError() <<
"expects all outputs to have the same shapes. "
1778 "Shape at output-index "
1780 <<
" is not equal to the shape at output-index 0.";
1783 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1784 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1787 for (int64_t dimension : dimensionsRef) {
1788 if (dimension < 0 || dimension >= inputType.getRank()) {
1789 return emitOpError()
1790 <<
"dimensions for reduction should be in the range [0, "
1791 << inputType.getRank() - 1 <<
"].";
1793 dimensionsToReduce.insert(dimension);
1796 auto inputDims = inputType.getShape();
1797 auto initDims = initType.getShape();
1802 if (!dimensionsToReduce.count(en.index()))
1803 reducedInputDims.push_back(en.value());
1806 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1807 return emitOpError() <<
"number of dimensions after reduction "
1808 << reducedInputDims.size()
1809 <<
" doesn't match the init rank "
1810 << initType.getRank();
1813 if (reducedInputDims != initDims)
1814 return emitOpError() <<
"init dimensions [" << initDims
1815 <<
"] doesn't match input dimensions after reduction ["
1816 << reducedInputDims <<
"]";
1818 Block *block = getBody();
1820 return emitOpError()
1821 <<
"mismatching number of operands and block arguments";
1824 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1825 Type inputElementType =
1826 llvm::cast<ShapedType>(input.getType()).getElementType();
1827 if (inputElementType != bbArg.getType())
1828 return emitOpError()
1829 <<
"input element type " << inputElementType
1830 <<
" does not match corresponding block argument type "
1835 for (
auto [output, bbArg] : llvm::zip(
1836 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1837 auto outputElementType =
1838 llvm::cast<ShapedType>(output.getType()).getElementType();
1839 if (outputElementType != bbArg.getType())
1840 return emitOpError()
1841 <<
"output element type " << outputElementType
1842 <<
" does not match corresponding block argument type "
1858 b.
create<linalg::YieldOp>(loc, args[0]);
1873 if (llvm::isa<RankedTensorType>(initType))
1902 void TransposeOp::getAsmResultNames(
1904 if (!getResults().empty())
1905 setNameFn(getResults().front(),
"transposed");
1918 return emitOpError(
"permutation is not valid");
1920 auto inputType = getInput().getType();
1921 auto initType = getInit().getType();
1923 int64_t rank = inputType.getRank();
1925 if (rank != initType.getRank())
1926 return emitOpError() <<
"input rank " << rank
1927 <<
" does not match init rank " << initType.getRank();
1929 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
1930 return emitOpError() <<
"size of permutation " << permutationRef.size()
1931 <<
" does not match the argument rank " << rank;
1933 auto inputDims = inputType.getShape();
1934 auto initDims = initType.getShape();
1936 for (int64_t i = 0; i < rank; ++i) {
1937 int64_t inputDim = inputDims[permutationRef[i]];
1938 int64_t initDim = initDims[i];
1940 if (inputDim != initDim) {
1941 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
1942 <<
" doesn't match dim(input, permutation[" << i
1943 <<
"]) = " << inputDim;
1951 int64_t rank = getInit().getType().getRank();
1955 ArrayAttr TransposeOp::getIndexingMaps() {
1957 int64_t rank = getInit().getType().getRank();
1960 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
1964 void TransposeOp::getEffects(
1974 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1977 if (!isa<TensorType>(getInput().
getType()))
1981 if (getPermutation().size() == 0) {
1982 result.push_back(getInput());
1987 result.push_back(getInput());
2000 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2001 if (!defTransposeOp)
2006 foldedPerms.reserve(perms.size());
2007 for (int64_t perm : perms)
2008 foldedPerms.push_back(defPerms[perm]);
2011 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2025 Value input = transposeOp.getInput();
2026 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2037 unsigned dimensionSize = dimensions.size();
2038 for (
unsigned i = 0; i < dimensionSize; ++i)
2039 resultDimensions.push_back(invertPerm[dimensions[i]]);
2042 Value broadcastInput = broadcastOp.getInput();
2043 Location loc = transposeOp.getLoc();
2046 auto broadcastInputTy =
2047 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2048 unsigned inputRank = broadcastInputTy.getRank();
2049 for (
unsigned i = 0; i < inputRank; ++i) {
2050 if (broadcastInputTy.isDynamicDim(i)) {
2051 dims.push_back(rewriter.
create<tensor::DimOp>(loc, broadcastInput, i)
2055 broadcastInputTy.getDimSize(i)));
2060 Value transposeInit = rewriter.
create<tensor::EmptyOp>(
2061 transposeOp.getLoc(), transposeResultShapes,
2062 broadcastInputTy.getElementType());
2065 Value transposeResult =
2067 .
create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2071 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2096 if (llvm::isa<RankedTensorType>(initType))
2125 void BroadcastOp::getAsmResultNames(
2127 if (!getResults().empty())
2128 setNameFn(getResults().front(),
"broadcasted");
2140 auto inputType = getInput().getType();
2141 auto initType = getInit().getType();
2143 int64_t inputRank = inputType.getRank();
2144 int64_t initRank = initType.getRank();
2146 auto inputShape = inputType.getShape();
2147 auto initShape = initType.getShape();
2149 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2150 return emitOpError() <<
"input rank plus added dimensions does not "
2151 "match init rank. input rank: "
2153 <<
", dimensions size: " << dimensionsRef.size()
2154 <<
", init rank: " << initRank;
2157 if (dim < 0 || dim >= initRank)
2158 return emitOpError() <<
"dimension " << idx
2159 <<
" is out of range. expected range: [0, "
2160 << initRank - 1 <<
"], got: " << dim;
2165 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2166 if (!llvm::is_contained(dimensionsRef, dim))
2167 dimMap.push_back(dim);
2170 for (
const auto &[inputDimIdx, initDimIdx] :
llvm::enumerate(dimMap)) {
2173 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2174 return emitOpError() <<
"input dim " << inputDimIdx
2175 <<
" should match init dim " << initDimIdx
2176 <<
". input: " << inputShape[inputDimIdx]
2177 <<
", init: " << initShape[initDimIdx];
2184 int64_t rank = getInit().getType().getRank();
2188 ArrayAttr BroadcastOp::getIndexingMaps() {
2190 int64_t rank = getInit().getType().getRank();
2196 void BroadcastOp::getEffects(
2208 results.
add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2216 if (getNumOperands() > 0)
2217 p <<
' ' << getOperands();
2219 if (getNumOperands() > 0)
2220 p <<
" : " << getOperandTypes();
2235 static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2236 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2237 return op.emitOpError(
"expected number of yield values (")
2238 << op.getNumOperands()
2239 <<
") to match the number of inits / outs operands of the enclosing "
2240 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2242 for (
OpOperand &opOperand : op->getOpOperands()) {
2244 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2246 if (isa<MemRefType, RankedTensorType>(elementType))
2248 if (opOperand.get().getType() != elementType)
2249 return op.emitOpError(
"type of yield operand ")
2250 << (opOperand.getOperandNumber() + 1) <<
" ("
2251 << opOperand.get().getType() <<
") doesn't match "
2252 <<
"the element type of the enclosing linalg.generic op ("
2253 << elementType <<
")";
2259 auto *parentOp = (*this)->getParentOp();
2260 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2261 return emitOpError(
"expected single non-empty parent region");
2263 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2266 return emitOpError(
"expected parent op with LinalgOp interface");
2274 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2276 return emitOpError(
"expected parent op with LinalgOp interface");
2277 if (linalgOp.getNumLoops() <= getDim())
2278 return emitOpError(
"expected dim (")
2279 << getDim() <<
") to be lower than the number of loops ("
2280 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2286 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2288 #define GET_OP_CLASSES
2289 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2291 #define GET_OP_CLASSES
2292 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2293 #define GET_OP_CLASSES
2294 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2311 for (
unsigned i = 0; i < num; ++i)
2318 auto rangeA = llvm::make_range(a.begin(), a.end());
2319 auto rangeB = llvm::make_range(b.begin(), b.end());
2320 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2321 return llvm::to_vector<4>(concatRanges);
2325 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2327 for (
auto size : memref.getShape())
2334 if (
auto as = memref.getMemorySpace()) {
2335 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2336 ss <<
"as" << attr.getInt();
2342 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2345 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2358 assert(isa<LinalgOp>(op));
2360 std::string fun =
"";
2362 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2363 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2364 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2365 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2369 std::replace(name.begin(), name.end(),
'.',
'_');
2370 llvm::raw_string_ostream ss(name);
2374 return std::string();
2389 LogicalResult matchAndRewrite(LinalgOp op,
2391 for (
OpOperand &opOperand : op->getOpOperands()) {
2395 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2398 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2409 struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2412 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2417 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2424 if (castOp->getBlock() != linalgOp->getBlock())
2431 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2434 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2440 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2442 rewriter.
create<tensor::CastOp>(loc, resultType, outOperand->
get());
2445 linalgOp.getDpsInits().end());
2446 outputOperands[resultNumber] = newOperand;
2447 newOperands.append(outputOperands.begin(), outputOperands.end());
2450 linalgOp->result_type_end());
2451 resultTypes[resultNumber] = resultType;
2452 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2459 results[resultNumber] = castBack;
2471 if (linalgOp.isScalar(&opOperand))
2473 Value src = opOperand.get();
2474 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2475 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2483 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2484 Value castSource = castOp.getSource();
2485 auto castSourceType =
2486 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2487 if (castSourceType && castSourceType.hasStaticShape())
2488 sourceShape = castSourceType.getShape();
2494 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2495 if (sourceType.isDynamicDim(i))
2497 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2498 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2508 static void createNewOperandWithStaticSizes(
2512 bool &changeNeeded) {
2514 newOperands.push_back(src);
2515 if (linalgOp.isScalar(opOperand))
2517 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2518 Type resultType = sourceType;
2519 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2520 resultTypes.push_back(resultType);
2524 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2528 bool newOperandNeeded =
false;
2529 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2530 int64_t dimShape = sourceShape[i];
2532 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2533 newShape.push_back(dimShape);
2539 newShape.push_back(affineExprToSize[dimExpr]);
2540 newOperandNeeded =
true;
2543 if (newOperandNeeded) {
2544 changeNeeded =
true;
2547 Value newOperand = rewriter.
create<tensor::CastOp>(loc, resultType, src);
2549 newOperands[index] = newOperand;
2551 if (linalgOp.isDpsInit(opOperand))
2552 resultTypes.push_back(resultType);
2561 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2563 if (!linalgOp.hasPureTensorSemantics())
2567 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](
AffineMap map) {
2568 return !map.isProjectedPermutation();
2578 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2585 bool changeNeeded =
false;
2586 newOperands.reserve(linalgOp->getNumOperands());
2587 resultTypes.reserve(linalgOp.getNumDpsInits());
2590 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
2591 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2592 affineExprToSize, linalgOp, newOperands,
2593 resultTypes, changeNeeded);
2602 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2605 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2606 Value newResult = std::get<1>(it);
2607 Value oldResult = std::get<0>(it);
2610 replacements.push_back(
2611 (newType != oldType)
2612 ? rewriter.
create<tensor::CastOp>(loc, oldType, newResult)
2615 rewriter.
replaceOp(linalgOp, replacements);
2630 ShapedType inputType = getInputOperandType();
2631 ShapedType outputType = getOutputOperandType();
2636 return emitOpError(
"incompatible output shape");
2638 int64_t inputRank = getInputOperandRank();
2639 int64_t dimension = getDimension();
2640 if ((dimension < 0) || (dimension >= inputRank))
2641 return emitOpError(
"incorrect dimension specified");
2647 int64_t operandRank = getInputOperandRank();
2650 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
2651 Value one = builder.
create<arith::ConstantIndexOp>(loc, 1);
2652 Value source = getInput();
2653 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2654 loopBounds[dim].offset = zero;
2655 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2656 loopBounds[dim].stride = one;
2663 utils::IteratorType::parallel);
2664 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2665 return iteratorTypes;
2668 FailureOr<TilingResult>
2672 int64_t rank = getInputOperandRank();
2677 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2679 return emitOpError(
"failed to compute input slice");
2681 tiledOperands.emplace_back(inputSlice->
getResult(0));
2683 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2685 return emitOpError(
"failed to compute output slice");
2687 tiledOperands.emplace_back(outputSlice->
getResult(0));
2690 if (hasPureTensorSemantics())
2691 resultTypes.push_back(tiledOperands[1].
getType());
2693 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2705 if (resultNumber == 0) {
2706 resultOffsets.assign(offsets.begin(), offsets.end());
2707 resultSizes.assign(sizes.begin(), sizes.end());
2722 Location loc = getOperation()->getLoc();
2724 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2725 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2726 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2727 if (!outputShapedType.isDynamicDim(dim)) {
2729 shapes.push_back(b.
getIndexAttr(inputShapedType.getDimSize(dim)));
2736 reifiedReturnShapes.emplace_back(std::move(shapes));
2740 void SoftmaxOp::getEffects(
2744 if (!llvm::isa<MemRefType>(operand.
getType()))
2747 &getOperation()->getOpOperand(index), 0,
2752 for (
OpOperand &operand : getDpsInitsMutable()) {
2753 if (!llvm::isa<MemRefType>(operand.get().
getType()))
2786 int64_t dim,
bool allParallel =
false) {
2788 utils::IteratorType::parallel);
2790 iteratorTypes[dim] = utils::IteratorType::reduction;
2794 for (
int i = 0; i < inputRank; i++) {
2801 return std::make_tuple(iteratorTypes, indexingMaps);
2806 template <
typename T>
2809 auto inputType = cast<ShapedType>(input.
getType());
2811 int64_t inputRank = inputShape.size();
2812 auto [iteratorTypes, indexingMaps] =
2814 assert(indexingMaps.size() == 2 &&
2815 "We should have two maps: 1 for the input, 1 for the output");
2816 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2818 auto genericOp = builder.
create<linalg::GenericOp>(
2819 loc, output.
getType(), input, output, indexingMaps, iteratorTypes,
2821 Value result = b.create<T>(loc, args[0], args[1]);
2822 b.create<linalg::YieldOp>(loc, result);
2832 auto inputType = cast<ShapedType>(input.
getType());
2834 int64_t inputRank = inputShape.size();
2836 builder, inputRank, dim,
true);
2837 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
2838 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2840 indexingMaps.push_back(indexingMaps[0]);
2841 auto genericOp = builder.
create<linalg::GenericOp>(
2844 Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2845 Value result = b.create<math::ExpOp>(loc, diff);
2846 b.create<linalg::YieldOp>(loc, result);
2857 Value denominator,
Value output, int64_t dim) {
2858 auto inputType = cast<ShapedType>(numerator.
getType());
2860 int64_t inputRank = inputShape.size();
2862 builder, inputRank, dim,
true);
2863 assert(indexingMaps.size() == 2 &&
2864 "We should have one map for each input (2)");
2865 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
2867 indexingMaps.push_back(indexingMaps[0]);
2868 auto genericOp = builder.
create<linalg::GenericOp>(
2870 indexingMaps, iteratorTypes,
2872 Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2873 b.create<linalg::YieldOp>(loc, result);
2897 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(
OpBuilder &b) {
2901 Value input = getInput();
2902 ShapedType inputType = getInputOperandType();
2903 Type elementType = inputType.getElementType();
2904 int64_t reductionDim = getDimension();
2906 Value output = getOutput();
2907 dims.erase(dims.begin() + reductionDim);
2909 Value outputReduce = b.
create<tensor::EmptyOp>(loc, dims, elementType);
2911 elementType, b, loc,
2913 Value neutralForMaxFInit =
2914 b.
create<linalg::FillOp>(loc,
Value{neutralForMaxF}, outputReduce)
2917 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2926 b.
create<linalg::FillOp>(loc,
Value{zero}, outputReduce).result();
2928 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2932 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2941 auto filterType = cast<ShapedType>(getFilter().
getType());
2943 int64_t filterH = filterShape[getFilterHDim()];
2944 int64_t filterW = filterShape[getFilterWDim()];
2948 if (filterH != r && filterH != 1)
2949 return emitOpError(
"expect filter height either equals to r or 1");
2950 if (filterW != r && filterW != 1)
2951 return emitOpError(
"expect filter width either equals to r or 1");
2952 if (filterH == 1 && filterW == 1)
2953 return emitOpError(
"expect either filter height or width equals to r");
2956 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2957 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2958 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2959 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2961 auto outputType = cast<ShapedType>(getOutput().
getType());
2964 return emitOpError(
"the output shape is not expected");
2970 WinogradFilterTransformOp::getIterationDomain(
OpBuilder &builder) {
2974 Value filter = getFilter();
2975 int64_t filterRank = getFilterOperandRank();
2977 for (
unsigned dim = 0; dim < filterRank; ++dim) {
2978 loopBounds[dim].offset = zeroAttr;
2979 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
2980 loopBounds[dim].stride = oneAttr;
2986 WinogradFilterTransformOp::getLoopIteratorTypes() {
2987 int64_t filterRank = getFilterOperandRank();
2989 utils::IteratorType::parallel);
2990 return iteratorTypes;
2998 ShapedType filterType = getFilterOperandType();
3000 int64_t filterH = filterShape[getFilterHDim()];
3001 int64_t filterW = filterShape[getFilterWDim()];
3004 int64_t alpha = m + r - 1;
3005 int64_t alphaH = filterH != 1 ? alpha : 1;
3006 int64_t alphaW = filterW != 1 ? alpha : 1;
3010 resultOffsets.append(
3011 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3013 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3029 ShapedType filterType = getFilterOperandType();
3031 int64_t filterH = filterShape[getFilterHDim()];
3032 int64_t filterW = filterShape[getFilterWDim()];
3038 sliceOffsets.append(
3039 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3040 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3041 sizes[getFilterCDim()]});
3042 int64_t filterRank = getFilterOperandRank();
3045 auto filterSlice = builder.
create<tensor::ExtractSliceOp>(
3046 loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3047 tiledOperands.emplace_back(filterSlice);
3054 int64_t outputRank = getOutputOperandRank();
3056 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3057 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3058 tiledOperands.emplace_back(outputSlice);
3061 resultTypes.push_back(tiledOperands[1].
getType());
3063 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3076 auto inputType = cast<ShapedType>(getInput().
getType());
3078 int64_t inputH = inputShape[getInputHDim()];
3079 int64_t inputW = inputShape[getInputWDim()];
3082 int64_t tileSize = m + r - 1;
3084 auto outputType = cast<ShapedType>(getOutput().
getType());
3086 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3087 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3090 if (ShapedType::isDynamic(inputH)) {
3091 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3092 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3094 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3095 expectedOutputShape[getOutputTileHDim()] =
3096 leftTransform ? (inputH - (r - 1)) / m : inputH;
3098 if (ShapedType::isDynamic(inputW)) {
3099 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3100 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3102 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3103 expectedOutputShape[getOutputTileWDim()] =
3104 rightTransform ? (inputW - (r - 1)) / m : inputW;
3106 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3107 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3110 return emitOpError(
"the output shape is not expected");
3116 WinogradInputTransformOp::getIterationDomain(
OpBuilder &builder) {
3120 Value output = getOutput();
3121 int64_t outputRank = getOutputOperandRank();
3123 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3124 loopBounds[dim].offset = zeroAttr;
3126 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3127 loopBounds[dim].stride = oneAttr;
3133 WinogradInputTransformOp::getLoopIteratorTypes() {
3134 int64_t outputRank = getOutputOperandRank();
3136 utils::IteratorType::parallel);
3137 return iteratorTypes;
3145 ShapedType outputType = getOutputOperandType();
3147 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3148 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3152 int64_t alpha = m + r - 1;
3153 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3154 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3159 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3160 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3161 offsets[getOutputCDim()]});
3162 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3163 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3164 sizes[getOutputCDim()]});
3175 FailureOr<TilingResult>
3183 ShapedType outputType = getOutputOperandType();
3185 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3186 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3190 auto identityAffineMap =
3192 auto offsetAffineMap =
3195 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3196 offsets[getOutputTileHDim()]);
3198 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3199 offsets[getOutputTileWDim()]);
3203 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3205 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3212 sliceOffsets.append(
3213 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3219 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3220 int64_t inputRank = getInputOperandRank();
3222 auto inputSlice = builder.
create<tensor::ExtractSliceOp>(
3223 loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3224 tiledOperands.emplace_back(inputSlice);
3231 int64_t outputRank = getOutputOperandRank();
3233 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3234 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3235 tiledOperands.emplace_back(outputSlice);
3238 resultTypes.push_back(tiledOperands[1].
getType());
3240 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3253 auto valueType = cast<ShapedType>(getValue().
getType());
3255 int64_t valueH = valueShape[getValueAlphaHDim()];
3256 int64_t valueW = valueShape[getValueAlphaWDim()];
3257 int64_t valueTileH = valueShape[getValueTileHDim()];
3258 int64_t valueTileW = valueShape[getValueTileWDim()];
3261 bool leftTransform = valueH != 1;
3262 bool rightTransform = valueW != 1;
3264 int64_t outputRank = getOutputOperandRank();
3266 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3267 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3269 if (valueH != (leftTransform ? m + r - 1 : 1))
3270 return emitOpError(
"expect input height equals to input tile size");
3271 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3273 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3274 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3276 if (valueW != (rightTransform ? m + r - 1 : 1))
3277 return emitOpError(
"expect input width equals to input tile size");
3278 expectedOutputShape[getOutputWDim()] =
3279 (rightTransform ? m : 1) * valueTileW;
3281 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3282 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3284 auto outputType = cast<ShapedType>(getOutput().
getType());
3287 return emitOpError(
"the output shape is not expected");
3293 WinogradOutputTransformOp::getIterationDomain(
OpBuilder &builder) {
3297 Value value = getValue();
3298 int64_t valueRank = getValueOperandRank();
3300 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3301 loopBounds[dim].offset = zeroAttr;
3303 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3304 loopBounds[dim].stride = oneAttr;
3310 WinogradOutputTransformOp::getLoopIteratorTypes() {
3311 int64_t valueRank = getValueOperandRank();
3313 utils::IteratorType::parallel);
3314 return iteratorTypes;
3325 auto identityAffineMap =
3330 ShapedType valueType = getValueOperandType();
3332 int64_t valueH = valueShape[0];
3333 int64_t valueW = valueShape[1];
3335 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3336 offsets[getValueTileHDim()]);
3338 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3339 offsets[getValueTileWDim()]);
3341 builder, loc, affineMap, sizes[getValueTileHDim()]);
3343 builder, loc, affineMap, sizes[getValueTileWDim()]);
3353 resultOffsets.append(
3354 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3356 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3375 ShapedType valueType = getValueOperandType();
3377 int64_t alphaH = valueShape[getValueAlphaHDim()];
3378 int64_t alphaW = valueShape[getValueAlphaWDim()];
3382 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3383 offsets[getValueTileWDim()], offsets[getValueNDim()],
3384 offsets[getValueFDim()]});
3385 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3386 sizes[getValueTileWDim()], sizes[getValueNDim()],
3387 sizes[getValueFDim()]});
3388 int64_t valueRank = getValueOperandRank();
3390 auto valueSlice = builder.
create<tensor::ExtractSliceOp>(
3391 loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3392 tiledOperands.emplace_back(valueSlice);
3399 int64_t outputRank = getOutputOperandRank();
3401 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3402 loc, getOutput(), resultOffsets, resultSizes, strides);
3403 tiledOperands.emplace_back(outputSlice);
3406 resultTypes.push_back(tiledOperands[1].
getType());
3408 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3427 llvm::set_union(explicitSet, defaultSet);
3428 return explicitSet == defaultSet;
3448 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3450 auto opIndexingMap = opIndexingMaps[opIndex];
3451 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3454 return matmulOp->emitOpError()
3455 <<
"Unexpected dim expression in map result.";
3458 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3459 return matmulOp->emitOpError()
3460 <<
"Invalid broadcast requested, should be (d2).";
3470 AffineMap defaultIndexingMap,
bool isLHS) {
3473 return batchMatmulOp->emitOpError()
3474 <<
"Unexpected result dim expression (outside the set of default "
3479 return batchMatmulOp->emitOpError()
3480 <<
"no. of result dim expressions exceeds 3.";
3482 auto hasValidBatchDim = [](
AffineMap map) {
3489 if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3490 return batchMatmulOp->emitOpError() <<
"Invalid broadcast requested.";
3491 }
else if (!hasValidBatchDim(opIndexingMap)) {
3492 return batchMatmulOp->emitOpError()
3493 <<
"Invalid batch dimension expression.";
3504 return batchMatmulOp->emitOpError()
3505 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3508 auto areValidOutputResultDim = [](
AffineMap outputMap) {
3509 return outputMap.getResult(0).isFunctionOfDim(0) &&
3510 outputMap.getResult(1).isFunctionOfDim(1) &&
3511 outputMap.getResult(2).isFunctionOfDim(2);
3514 if (!areValidOutputResultDim(opIndexingMap))
3515 return batchMatmulOp->emitOpError()
3516 <<
"Invalid output map result dimension.";
3523 static LogicalResult
3527 batchMatmulOp.getIndexingMapsArray();
3529 batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
3531 if (opIndexingMaps.size() != 3)
3532 return batchMatmulOp->emitOpError()
3533 <<
"Indexing_map attribute must have 3 affine maps.";
3535 auto opIndexingMap = opIndexingMaps[opIndex];
3536 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3538 if (opIndex == 2 && failed(
verifyOutputMap(batchMatmulOp, opIndexingMap)))
3541 if (failed(
verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
3563 return indexingMaps;
3568 utils::IteratorType::parallel,
3569 utils::IteratorType::reduction};
3572 unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3574 std::string MatmulOp::getLibraryCallName() {
3578 bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
3582 bool MatmulOp::hasUserDefinedMaps() {
3586 return defaultMaps != explicitMaps;
3594 "MatmulOp regionBuilder expects 3 (>=0) args");
3595 RegionBuilderHelper helper(b, block);
3598 TypeFn castVal = TypeFn::cast_signed;
3599 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3600 return attr.
getName() ==
"cast";
3602 if (castIter != attrs.end()) {
3603 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3611 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3613 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), value3);
3614 yields.push_back(value4);
3615 helper.yieldOutputs(yields);
3619 bool MatmulOp::isValidLhsRhsBroadcastMap(
AffineMap bcastMap) {
3620 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
3631 ArrayAttr arrayAttr;
3635 if (llvm::any_of(arrayAttr,
3636 [](
auto elt) {
return !dyn_cast<AffineMapAttr>(elt); }))
3638 <<
"element of indexing_maps array is not an affine_map";
3645 if (failed(indexingMapsAttr))
3648 if (*indexingMapsAttr ==
nullptr) {
3649 auto indexingMapAttrs = llvm::map_to_vector(
3650 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
3655 result.
addAttribute(
"indexing_maps", *indexingMapsAttr);
3657 MatmulOp::getRegionBuilder());
3662 MatmulOp::getDefaultIndexingMaps(
getContext()),
3664 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
3665 p <<
" indexing_maps = [";
3666 llvm::interleaveComma(getIndexingMaps(), p,
3672 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
3680 if (!hasUserDefinedMaps())
3683 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
3694 void MatmulOp::getEffects(
3697 if (hasPureTensorSemantics())
3711 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3721 for (
auto result : outAffineMap.
getResults()) {
3722 auto dimExpr = dyn_cast<AffineDimExpr>(result);
3723 assert(dimExpr &&
"affine_map is a projected permutation");
3724 dimsInOutput[dimExpr.getPosition()] =
true;
3728 for (
auto dimOccursInOutput : dimsInOutput)
3729 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3730 : utils::IteratorType::reduction);
3732 return iteratorTypes;
3735 unsigned ContractOp::getNumRegionArgs() {
return 3; }
3741 "ContractOp regionBuilder expects 3 args");
3742 RegionBuilderHelper helper(b, block);
3744 TypeFn castSignedness = TypeFn::cast_signed;
3745 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3746 return attr.
getName() ==
"cast";
3748 if (castIter != attrs.end()) {
3749 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3755 Value lhsAtOutType =
3756 helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
3757 Value rhsAtOutType =
3758 helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
3759 Value productAtOutType =
3760 helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
3763 helper.yieldOutputs({result});
3768 if (failed(indexingMapsAttr) || *indexingMapsAttr ==
nullptr)
3770 "expected 'indexing_maps' attribute");
3771 result.
addAttribute(
"indexing_maps", *indexingMapsAttr);
3778 p <<
" indexing_maps = [";
3779 llvm::interleaveComma(getIndexingMaps(), p,
3783 p, getOperation(), getInputs(), getOutputs(),
3784 {
"indexing_maps",
"operandSegmentSizes"});
3788 int iterationSpaceDims = -1;
3797 auto checkAffineMapAndType = [&](
AffineMap affineMap,
Type operandType,
3798 bool isInput) -> LogicalResult {
3801 return emitError(
"provided affine_map is not a projected permutation");
3804 if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
3806 return emitError(
"ranks of shaped operand and results of corresponding "
3807 "affine_map differ");
3809 return emitError(
"affine_map specifies shaped access while operand has "
3814 if (iterationSpaceDims == -1) {
3818 }
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
3819 return emitError(
"iteration spaces of provided affine_maps differ");
3824 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
3826 llvm_unreachable(
"affine_map is a projected permutation");
3829 inOccurrences[affineDimExpr.getPosition()] += 1;
3831 outOccurrences[affineDimExpr.getPosition()] += 1;
3837 for (
auto &&[affineMap, operandType, isInput] :
3838 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
3840 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3844 bool hasContractingDim =
false;
3845 for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3846 size_t inOccCount = inOccurrences[dimIndex];
3847 size_t outOccCount = outOccurrences[dimIndex];
3850 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
3852 if (inOccCount == 0 && outOccCount == 0)
3853 return emitError() <<
"iteration space dim at index " << dimIndex
3854 <<
" not used to access any operand";
3865 if (inOccCount == 1 && outOccCount != 1)
3867 <<
"iteration space dim at index " << dimIndex
3868 <<
" is neither a contracting dim nor of parallel iteration type";
3871 if (!hasContractingDim)
3872 return emitError(
"'indexing_maps' do not specify a contracting dimension");
3881 void ContractOp::getEffects(
3884 if (hasPureTensorSemantics())
3897 BatchMatmulOp::getDefaultIndexingMaps(
MLIRContext *context) {
3901 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
3902 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
3903 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
3904 return indexingMaps;
3909 utils::IteratorType::parallel, utils::IteratorType::parallel,
3910 utils::IteratorType::parallel, utils::IteratorType::reduction};
3913 unsigned BatchMatmulOp::getNumRegionArgs() {
return 3; }
3915 std::string BatchMatmulOp::getLibraryCallName() {
3921 bool BatchMatmulOp::hasUserDefinedMaps() {
3925 return defaultMaps != explicitMaps;
3929 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(
AffineMap bcastMap,
bool isLHS) {
3931 "Expected less than 3 result dim expr.");
3932 bool isValid =
false;
3933 enum Indices { batchPos, mPos, nPos, kPos };
3950 "BatchMatmulOp regionBuilder expects 3 (>=0) args");
3951 RegionBuilderHelper helper(b, block);
3956 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(0));
3958 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(1));
3959 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
3961 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
3962 yields.push_back(addVal);
3963 helper.yieldOutputs(yields);
3979 if (!isa<AffineMapAttr>(mapAttr)) {
3981 "expected affine map attribute");
3983 indexingMapsAttr.push_back(mapAttr);
3993 if (indexingMapsAttr.empty()) {
3994 indexingMapsAttr = llvm::map_to_vector(
3995 BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4002 BatchMatmulOp::getNumRegionArgs(),
4003 BatchMatmulOp::getRegionBuilder());
4008 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4013 BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
4015 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
4016 p <<
" indexing_maps = [";
4017 llvm::interleaveComma(getIndexingMaps(), p,
4027 if (!hasUserDefinedMaps())
4030 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
4037 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4042 void BatchMatmulOp::getEffects(
4045 if (hasPureTensorSemantics())
4068 for (
auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4070 .take_back(mixedTiles.size()),
4072 int64_t shape = std::get<0>(it);
4073 if (shape == ShapedType::kDynamic) {
4074 newMixedTileSizes.push_back(std::get<1>(it));
4081 if (
Attribute attr = llvm::dyn_cast_if_present<Attribute>(
tile)) {
4083 newMixedTileSizes.push_back(
tile);
4086 "tile size and dim size don't match!");
4087 newMixedTileSizes.push_back(
4092 return newMixedTileSizes;
4095 template <
typename OpTy>
4096 static LogicalResult
4099 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4100 "applies to only pack or unpack operations");
4101 int64_t destRank = op.getDestRank();
4103 reifiedReturnShapes[0] =
4108 template <
typename OpTy>
4110 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4111 "applies to only pack or unpack operations");
4115 assert(tiles.size() == dimsToTile.size() &&
4116 "tiles must match indices of dimension to block");
4118 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
4119 dimAndTileMapping[dimsToTile[i]] = tiles[i];
4120 return dimAndTileMapping;
4123 template <
typename OpTy>
4125 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4126 "applies to only pack or unpack operations");
4129 unsigned dynamicValIndex = 0;
4130 for (int64_t staticTile : op.getStaticInnerTiles()) {
4131 if (!ShapedType::isDynamic(staticTile))
4134 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
4136 return mixedInnerTiles;
4139 template <
typename OpTy>
4141 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4142 "applies to only pack or unpack operations");
4155 size_t dimsPosSize = dimsPos.size();
4156 if (dimsPosSize > rank)
4159 for (int64_t dim : dimsPos)
4160 uniqued.insert(dim);
4161 if (dimsPosSize != uniqued.size())
4163 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
4164 return dimPos < 0 || dimPos >=
static_cast<int64_t
>(rank);
4173 sourceShape.size() == limitShape.size() &&
4174 "expected source shape rank, and limit of the shape to have same rank");
4175 return llvm::all_of(
4176 llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
4177 int64_t sourceExtent = std::get<0>(it);
4178 int64_t limit = std::get<1>(it);
4179 return ShapedType::isDynamic(sourceExtent) ||
4180 ShapedType::isDynamic(limit) || sourceExtent <= limit;
4184 template <
typename OpTy>
4186 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4187 "applies to only pack or unpack operations");
4188 Operation *op = packOrUnPack.getOperation();
4192 return llvm::any_of(
4198 if (hasZeros(mixedTiles))
4199 return op->
emitError(
"invalid zero tile factor");
4202 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4203 ? packOrUnPack.getSourceType()
4204 : packOrUnPack.getDestType();
4205 size_t unpackedRank = unpackedType.getRank();
4209 return op->
emitError(
"invalid inner_dims_pos vector");
4211 return op->
emitError(
"invalid outer_dims_perm vector");
4212 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
4213 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
4217 if (mixedTiles.size() > unpackedRank) {
4218 return op->
emitError(
"tiling factors must be less than or equal to the "
4219 "input rank for pack or output rank for unpack");
4223 "tiling factors must equal the number of dimensions to tile");
4226 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4227 ? packOrUnPack.getDestType()
4228 : packOrUnPack.getSourceType();
4229 size_t packedRank = packedType.getRank();
4231 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
4232 if (expectedPackedRank != packedRank) {
4234 "packed rank != (unpacked rank + num tiling factors), got ")
4235 << packedRank <<
" != " << expectedPackedRank;
4241 RankedTensorType expectedPackedType = PackOp::inferPackedType(
4242 unpackedType, packOrUnPack.getStaticTiles(),
innerDimsPos, outerDimPerm);
4243 if (!
areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
4244 return op->
emitError(
"the shape of output is not large enough to hold the "
4245 "packed data. Expected at least ")
4246 << expectedPackedType <<
", got " << packedType;
4249 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4251 [](std::tuple<int64_t, OpFoldResult> it) {
4252 int64_t shape = std::get<0>(it);
4253 if (Attribute attr =
4254 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4255 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4256 int64_t staticTileSize = intAttr.getValue().getSExtValue();
4257 return shape == staticTileSize;
4259 return ShapedType::isDynamic(shape);
4261 return op->emitError(
"mismatch in inner tile sizes specified and shaped of "
4262 "tiled dimension in the packed type");
4274 struct PackOrUnPackTransposeResult {
4281 template <
typename OpTy>
4282 static PackOrUnPackTransposeResult
4286 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4287 "applies to only pack or unpack operations");
4288 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4289 "some permutation must be non-empty");
4290 PackOrUnPackTransposeResult metadata;
4291 metadata.innerDimsPos =
4293 metadata.innerTiles =
4295 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4296 ? packOrUnPackOp.getSourceRank()
4297 : packOrUnPackOp.getDestRank();
4298 metadata.outerDimsPerm =
4299 packOrUnPackOp.getOuterDimsPerm().empty()
4300 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4302 if (!innerPermutation.empty()) {
4303 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4305 "invalid inner permutation");
4309 if (!outerPermutation.empty()) {
4310 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4312 "invalid outer permutation");
4322 void PackOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
4323 setNameFn(getResult(),
"pack");
4329 std::optional<Value> paddingValue,
4332 "number of tile sizes specified must match the specified number of "
4333 "original dimensions to be tiled");
4337 build(builder, state, dest.
getType(), source, dest,
4338 paddingValue ? *paddingValue :
nullptr,
4364 ShapedType inputType = getSourceType();
4365 int64_t inputRank = inputType.getRank();
4366 return getDestType().getShape().take_front(inputRank);
4371 auto packedShape = getDestType().getShape();
4375 res.push_back(packedShape[index]);
4386 outputShape.take_front(inputShape.size()));
4389 "expected output and outer_dims_perm to have same size");
4394 if (ShapedType::isDynamic(inputShape[pos]))
4398 if (!constantTile) {
4399 if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4400 (inputShape[pos] % outputTileSizes[pos] != 0))
4402 }
else if (inputShape[pos] % (*constantTile) != 0) {
4416 auto paddingValue = getPaddingValue();
4419 return emitOpError(
"expected padding_value has ")
4420 << getSourceType().getElementType()
4421 <<
" but got: " << paddingValue.getType();
4424 if (!paddingValue &&
4425 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
4426 getDestType().
getShape(), getOuterDimsPerm(),
4429 "invalid tile factor or output size provided. Only full tiles are "
4430 "supported when padding_value is not set");
4440 for (
auto o : ofrs) {
4442 if (llvm::dyn_cast_if_present<Value>(o))
4443 result.push_back(ShapedType::kDynamic);
4458 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4460 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4461 resultShape[tiledDim.value()] = ShapedType::kDynamic;
4464 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
4465 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4473 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4488 builder, loc, ceilDivExpr,
4489 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4493 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4504 for (
unsigned i = 0; i < resultDims.size(); ++i) {
4505 if (!ShapedType::isDynamic(resultTypeShape[i]))
4516 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4538 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
4539 if (ShapedType::isDynamic(value))
4540 mixedSizes.push_back(
4545 for (
auto it : llvm::zip(
innerDimsPos, innerTileSizes)) {
4546 int64_t dimPos = std::get<0>(it);
4548 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4551 applyPermutationToVector<OpFoldResult>(mixedSizes,
outerDimsPerm);
4553 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4554 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4555 return b.
create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4562 *
this, innerPermutation, outerPermutation);
4563 Value transposedDest =
4564 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4565 metadata.innerDimsPos, metadata.outerDimsPerm);
4566 return b.
create<PackOp>(loc, getSource(), transposedDest,
4567 metadata.innerDimsPos, metadata.innerTiles,
4568 getPaddingValue(), metadata.outerDimsPerm);
4572 template <
typename OpTy>
4574 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4575 "applies to only pack or unpack operations");
4576 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4578 : op.getSourceType();
4580 for (
auto [dimDest,
tile] : llvm::zip(
4581 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4583 if (!constTileSize || ShapedType::isDynamic(dimDest))
4590 if (getPaddingValue())
4605 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4607 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4619 auto packTiles = packOp.getMixedTiles();
4620 auto unPackTiles = unPackOp.getMixedTiles();
4621 if (packTiles.size() != unPackTiles.size())
4623 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
4632 auto srcType = op.getSourceType();
4633 if (llvm::any_of(op.getInnerDimsPos(),
4634 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4636 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4638 return !PackOp::requirePaddingValue(
4639 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4640 op.getOuterDimsPerm(), op.getMixedTiles());
4647 bool changeNeeded =
false;
4648 srcShape.assign(packOp.getSourceType().getShape().begin(),
4649 packOp.getSourceType().getShape().end());
4650 destShape.assign(packOp.getDestType().getShape().begin(),
4651 packOp.getDestType().getShape().end());
4652 llvm::SmallSetVector<int64_t, 4> innerDims;
4653 innerDims.insert(packOp.getInnerDimsPos().begin(),
4654 packOp.getInnerDimsPos().end());
4656 if (!packOp.getOuterDimsPerm().empty())
4658 int srcRank = packOp.getSourceRank();
4659 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
4660 if (innerDims.contains(i))
4663 int64_t destPos = i;
4664 if (!inverseOuterDimsPerm.empty())
4665 destPos = inverseOuterDimsPerm[srcPos];
4666 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4667 ShapedType::isDynamic(destShape[destPos])) {
4670 int64_t size = srcShape[srcPos];
4671 if (ShapedType::isDynamic(size))
4672 size = destShape[destPos];
4673 srcShape[srcPos] = size;
4674 destShape[destPos] = size;
4675 changeNeeded =
true;
4677 return changeNeeded;
4680 LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
4682 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4683 if (unPackOp.getSourceType() != packOp.getDestType())
4685 if (packOp.getPaddingValue() ||
4689 rewriter.
replaceOp(packOp, unPackOp.getSource());
4696 packOp.getPaddingValueMutable().clear();
4705 Value source = packOp.getSource();
4706 if (srcShape != packOp.getSourceType().getShape()) {
4707 auto newSrcType = packOp.getSourceType().clone(srcShape);
4709 rewriter.
create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4711 Value dest = packOp.getDest();
4712 RankedTensorType originalResultType = packOp.getDestType();
4713 bool needUpdateDestType = (destShape != originalResultType.getShape());
4714 if (needUpdateDestType) {
4715 auto newDestType = packOp.getDestType().clone(destShape);
4717 rewriter.
create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4720 packOp.getSourceMutable().assign(source);
4721 packOp.getDestMutable().assign(dest);
4722 packOp.getResult().setType(cast<RankedTensorType>(dest.
getType()));
4725 if (needUpdateDestType) {
4728 rewriter.
create<tensor::CastOp>(loc, originalResultType, packOp);
4737 template <
typename PackOrUnpackOp>
4739 RankedTensorType packedTensorType) {
4740 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
4741 std::is_same<PackOrUnpackOp, UnPackOp>::value,
4742 "Function meant for pack/unpack");
4748 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4755 int64_t packedRank = packedTensorType.getRank();
4765 return llvm::all_of(
4766 llvm::seq<int64_t>(0, packedRank - numPackedDims),
4767 [&packedShape](int64_t i) {
return packedShape[i] == 1; });
4770 bool PackOp::isLikePad() {
4771 auto packedTensorType =
4772 llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4777 std::optional<Attribute> paddingValue;
4778 if (
auto pad = adaptor.getPaddingValue())
4780 if (
OpFoldResult reshapedSource = reshapeConstantSource(
4781 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4782 getDestType(), paddingValue))
4783 return reshapedSource;
4821 PackOp newOp = rewriter.
create<PackOp>(
4822 op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
4823 newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
4827 Value oldResult = op.getResult();
4828 Value newResult = newOp.getResult();
4830 ? rewriter.
create<tensor::CastOp>(
4831 op->getLoc(), oldResult.
getType(), newResult)
4844 void UnPackOp::getAsmResultNames(
4846 setNameFn(getResult(),
"unpack");
4868 ShapedType destType = getDestType();
4869 int64_t destRank = destType.getRank();
4870 return getSourceType().getShape().take_front(destRank);
4875 auto packedShape = getSourceType().getShape();
4879 res.push_back(packedShape[index]);
4901 "number of tile sizes specified must match the specified number of "
4902 "original dimensions to be tiled");
4906 build(builder, state, dest.
getType(), source, dest,
4925 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
4927 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4928 if (srcType.isDynamicDim(i))
4929 mixedSizes.push_back(b.
create<tensor::DimOp>(loc, source, i).
getResult());
4931 mixedSizes.push_back(b.
getIndexAttr(srcType.getDimSize(i)));
4934 applyPermutationToVector<OpFoldResult>(
4938 for (
auto [dimPos, tileSize] : llvm::zip_equal(
innerDimsPos, innerTileSizes))
4939 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4941 auto elemType = srcType.getElementType();
4942 return b.
create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4946 Value transposedSource,
4950 *
this, innerPermutation, outerPermutation);
4951 return b.
create<UnPackOp>(loc, transposedSource, getDest(),
4952 metadata.innerDimsPos, metadata.innerTiles,
4953 metadata.outerDimsPerm);
4960 bool changeNeeded =
false;
4961 srcShape.assign(op.getSourceType().getShape().begin(),
4962 op.getSourceType().getShape().end());
4963 destShape.assign(op.getDestType().getShape().begin(),
4964 op.getDestType().getShape().end());
4965 llvm::SmallSetVector<int64_t, 4> innerDims;
4966 innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4968 if (!op.getOuterDimsPerm().empty())
4970 int destRank = op.getDestRank();
4971 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
4972 if (innerDims.contains(i))
4975 int64_t destPos = i;
4976 if (!inverseOuterDimsPerm.empty())
4977 srcPos = inverseOuterDimsPerm[destPos];
4978 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4979 ShapedType::isDynamic(destShape[destPos])) {
4982 int64_t size = srcShape[srcPos];
4983 if (ShapedType::isDynamic(size))
4984 size = destShape[destPos];
4985 srcShape[srcPos] = size;
4986 destShape[destPos] = size;
4987 changeNeeded =
true;
4989 return changeNeeded;
4992 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4995 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
4996 if (packOp.getSourceType() != unPackOp.getDestType())
4998 if (packOp.getPaddingValue() ||
5002 rewriter.
replaceOp(unPackOp, packOp.getSource());
5006 if (
auto dstStyleOp =
5007 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5008 auto destValue = cast<OpResult>(unPackOp.getDest());
5009 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5011 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5019 Value source = unPackOp.getSource();
5020 if (srcShape != unPackOp.getSourceType().getShape()) {
5021 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5022 source = rewriter.
create<tensor::CastOp>(loc, newSrcType,
5023 unPackOp.getSource());
5025 Value dest = unPackOp.getDest();
5026 if (destShape != unPackOp.getDestType().getShape()) {
5027 auto newDestType = unPackOp.getDestType().clone(destShape);
5029 rewriter.
create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
5032 loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
5033 unPackOp.getOuterDimsPerm());
5035 unPackOp, unPackOp.getResult().getType(), newOp);
5042 bool UnPackOp::isLikeUnPad() {
5043 RankedTensorType packedTensorType = getSourceType();
5048 if (
OpFoldResult reshapedSource = reshapeConstantSource(
5049 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5051 return reshapedSource;
5080 Value sourceTensor = newOperands[0];
5084 rewriter, sourceTensor.
getType(), op.getMixedTiles());
5090 UnPackOp newOp = rewriter.
create<UnPackOp>(
5091 op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
5092 newMixedTileSizes, op.getOuterDimsPerm());
5096 Value oldResult = op.getResult();
5097 Value newResult = newOp.getResult();
5099 ? rewriter.
create<tensor::CastOp>(
5100 op->getLoc(), oldResult.
getType(), newResult)
5116 void LinalgDialect::getCanonicalizationPatterns(
5125 return arith::ConstantOp::materialize(builder, value, type, loc);
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
SmallVector< int64_t > outerDimsPerm
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
SmallVector< OpFoldResult > innerTiles
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp has exactly 3 result di...
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static Operation * findPayloadOp(Block *body, bool initFirst=false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
SmallVector< int64_t > innerDimsPos
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override