37 #include "llvm/ADT/DenseMap.h"
38 #include "llvm/ADT/SmallSet.h"
39 #include "llvm/ADT/StringSet.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Support/FormatVariadic.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
66 assert(llvm::all_of(outputTypes,
67 [](Type t) { return llvm::isa<ShapedType>(t); }));
73 for (auto containers : {inputTypes, outputTypes}) {
74 for (auto t : containers) {
85 opBuilder.createBlock(&region, {}, argTypes, argLocs);
89 regionBuilder(b, *body, attrs);
101 std::optional<TypeRange> resultTensorTypes,
108 if (!resultTensorTypes)
109 copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
110 [](Type type) { return llvm::isa<RankedTensorType>(type); });
117 "operand_segment_sizes",
119 static_cast<int32_t>(outputs.size())}));
133 bool addOperandSegmentSizes = true) {
134 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
163 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
165 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
169 if (addOperandSegmentSizes) {
178 attrs.append("operand_segment_sizes",
180 {static_cast<int32_t>(inputsOperands.size()),
181 static_cast<int32_t>(outputsOperands.size())}));
186 {static_cast<int32_t>(inputsOperands.size()),
187 static_cast<int32_t>(outputsOperands.size())}));
191 std::optional<RegisteredOperationName> info =
195 return parser.emitError(attrsLoc)
196 << "'" << result.name.getStringRef() << "' op ";
207 p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
208 if (!outputs.empty())
209 p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
220 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
223 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
224 "region expects {0} args, got {1}",
225 numRegionArgs, inputTypes.size() + outputTypes.size()));
244 unsigned numRegionArgs,
256 result.addTypes(outputTensorsTypes);
258 std::unique_ptr<Region> region = std::make_unique<Region>();
270 if (resultTypes.empty())
279 {"operand_segment_sizes",
282 "linalg.memoized_indexing_maps"});
319 class RegionBuilderHelper {
322 : context(context), block(block) {}
325 Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
326 if (!isFloatingPoint(arg))
327 llvm_unreachable("unsupported non numeric type");
339 return builder.create<math::FloorOp>(arg.getLoc(), arg);
341 return builder.create<arith::NegFOp>(arg.getLoc(), arg);
343 llvm_unreachable("unsupported unary function");
348 bool allComplex = isComplex(arg0) && isComplex(arg1);
349 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
350 bool allInteger = isInteger(arg0) && isInteger(arg1);
353 if (!allComplex && !allFloatingPoint && !allInteger)
354 llvm_unreachable("unsupported non numeric type");
359 return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
360 if (allFloatingPoint)
361 return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
363 return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
364 return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
367 return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
368 if (allFloatingPoint)
369 return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
371 llvm_unreachable("unsupported operation: sub with bools");
372 return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
375 return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
376 if (allFloatingPoint)
377 return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
379 return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
380 return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
381 case BinaryFn::max_signed:
383 if (allFloatingPoint)
384 return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
385 return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
386 case BinaryFn::min_signed:
388 if (allFloatingPoint)
389 return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
390 return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
391 case BinaryFn::max_unsigned:
393 if (allFloatingPoint)
394 return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
395 return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
396 case BinaryFn::min_unsigned:
398 if (allFloatingPoint)
399 return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
400 return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
402 llvm_unreachable("unsupported binary function");
408 case TypeFn::cast_signed:
409 return cast(toType, operand, false);
410 case TypeFn::cast_unsigned:
411 return cast(toType, operand, true);
413 llvm_unreachable("unsupported type conversion function");
419 builder.create<YieldOp>(loc, values);
422 Value constant(const std::string &value) {
426 return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
429 Value index(int64_t dim) {
434 Type getIntegerType(unsigned width) {
448 auto loc = operand.getLoc();
452 bool isComplex(Value value) {
453 return llvm::isa<ComplexType>(value.getType());
455 bool isFloatingPoint(Value value) {
456 return llvm::isa<FloatType>(value.getType());
458 bool isInteger(Value value) {
459 return llvm::isa<IntegerType>(value.getType());
484 template <typename TensorReshapeOp>
489 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
494 auto newInit = rewriter.create<TensorReshapeOp>(
495 loc, reshapeOp.getResultType(), oldFill.output(),
496 reshapeOp.getReassociation());
511 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
517 Value padValue = padOp.getConstantPaddingValue();
518 if (!padValue || fillOp.value() != padValue)
524 padOp, "failed to reify tensor.pad op result shape");
526 auto emptyTensor = rewriter.create<tensor::EmptyOp>(
527 padOp.getLoc(), reifiedShape.front(),
528 padOp.getResultType().getElementType());
534 if (replacement.getType() != padOp.getResultType()) {
535 replacement = rewriter.create<tensor::CastOp>(
536 fillOp.getLoc(), padOp.getResultType(), replacement);
546 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
549 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
551 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
555 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
560 Value firstDest = insertOp.getDest();
561 while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
562 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
567 bool disjoint = false;
568 for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
571 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
572 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
573 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
577 int64_t prevStart = prevOp.getStaticOffset(i);
578 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
579 prevOp.getStaticStride(i);
580 int64_t nextStart = insertOp.getStaticOffset(i);
581 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
582 insertOp.getStaticStride(i);
583 if (prevEnd < nextStart || nextEnd < prevStart) {
591 firstDest = prevOp.getDest();
602 Value padValue = srcPadOp.getConstantPaddingValue();
603 if (!padValue || dstFillOp.value() != padValue)
619 for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
621 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
625 for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) {
627 rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
632 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
633 newSizes, insertOp.getMixedStrides());
643 .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
644 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
645 FoldInsertPadIntoFill>(context);
658 for (ValueRange container : {inputs, outputs}) {
659 for (Value v : container) {
661 blockArgLocs.push_back(v.getLoc());
667 builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
671 void GenericOp::getAsmBlockArgumentNames(Region &region,
673 for (Value v : getRegionInputArgs())
675 for (Value v : getRegionOutputArgs())
679 void GenericOp::build(
682 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
685 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
686 iteratorTypes, doc, libraryCall);
690 inputs, outputs, bodyBuild);
693 void GenericOp::build(
697 StringRef libraryCall,
700 build(builder, result, resultTensorTypes, inputs, outputs,
705 return IteratorTypeAttr::get(builder.getContext(), iter);
708 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
709 bodyBuild, attributes);
712 void GenericOp::build(
716 StringRef libraryCall,
719 build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
720 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
723 void GenericOp::build(
729 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
731 "", bodyBuild, attributes);
734 void GenericOp::build(
740 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
743 "", bodyBuild, attributes);
750 auto genericAttrNames = linalgTraitAttrNames();
753 genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
755 for (auto attr : (*this)->getAttrs()) {
756 if (attr.getName() == getIteratorTypesAttrName()) {
758 llvm::cast<ArrayAttr>(attr.getValue())
759 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
765 llvm::to_vector(llvm::map_range(
766 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
770 genericAttrs.emplace_back(
771 getIteratorTypesAttrName(),
773 } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
774 genericAttrs.push_back(attr);
777 if (!genericAttrs.empty()) {
779 p << genericDictAttr;
786 genericAttrNames.push_back("operand_segment_sizes");
787 genericAttrNamesSet.insert(genericAttrNames.back());
789 bool hasExtraAttrs = false;
791 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
801 if (!getRegion().empty()) {
811 DictionaryAttr dictAttr;
820 dictAttr.getValue().end());
826 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
828 if (!iteratorTypes) {
829 return parser.emitError(attributeLocation)
830 << "expected " << getIteratorTypesAttrName(result.name)
831 << " array attribute";
836 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
837 auto maybeIteratorType = utils::symbolizeIteratorType(s);
838 if (!maybeIteratorType.has_value())
840 << "unexpected iterator_type (" << s << ")";
842 iteratorTypeAttrs.push_back(
859 std::unique_ptr<Region> region = std::make_unique<Region>();
871 result.addTypes(outputTensorsTypes);
881 for (auto *operand : inputOperands) {
882 if (!llvm::isa<MemRefType>(operand->get().getType()))
887 for (auto *operand : outputOperands) {
888 if (!llvm::isa<MemRefType>(operand->get().getType()))
897 void GenericOp::getEffects(
901 getDpsInputOperands(), getDpsInitOperands());
919 if (llvm::any_of(genericOp.getIndexingMapsArray(),
920 [](AffineMap map) { return !map.isIdentity(); }))
926 if (!llvm::hasSingleElement(body))
928 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
933 if (genericOp.hasBufferSemantics()) {
934 if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
935 genericOp.getDpsInputOperand(0)->get() ==
936 genericOp.getDpsInitOperand(0)->get()) {
944 if (!genericOp.hasTensorSemantics())
951 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
952 if (!yieldArg || yieldArg.getOwner() != &body)
954 unsigned argumentNumber = yieldArg.getArgNumber();
955 Value returnedArg = genericOp->getOperand(argumentNumber);
956 Type resultType = genericOp->getResult(yieldVal.index()).getType();
960 if (returnType != resultType) {
965 returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
966 genericOp.getLoc(), resultType, returnedArg);
968 if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
971 returnedArg = rewriter.create<tensor::CastOp>(
972 genericOp.getLoc(), resultType, returnedArg);
975 returnedArgs.push_back(returnedArg);
978 if (returnedArgs.size() != genericOp->getNumResults())
980 rewriter.replaceOp(genericOp, returnedArgs);
989 results.add<EraseIdentityGenericOp>(context);
1011 for (Type outputType : outputTypes) {
1012 if (llvm::isa<RankedTensorType>(outputType))
1026 void MapOp::getAsmBlockArgumentNames(Region &region,
1028 for (Value v : getRegionInputArgs())
1033 if (!getResults().empty())
1034 setNameFn(getResults().front(), "mapped");
1041 build(builder, result, TypeRange{}, inputs, init);
1046 if (llvm::isa<RankedTensorType>(initType))
1051 inputs, {}, bodyBuild);
1058 bool initFirst = false) {
1064 for (auto &operand : operands) {
1066 llvm::cast<ShapedType>(operand.getType()).getElementType(),
1073 payloadOpOperands.push_back(block.getArguments().back());
1074 for (const auto &arg : block.getArguments().drop_back())
1075 payloadOpOperands.push_back(arg);
1084 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1091 std::optional<OperationName> payloadOpName;
1095 if (failed(operationName))
1099 payloadOpName = operationName.value();
1107 if (payloadOpName.has_value()) {
1141 for (const auto &[operand, bbArg] :
1143 if (bbArg != operand)
1147 for (const auto &[operand, bbArg] :
1149 if (bbArg != operand)
1158 std::string attrToElide;
1160 for (const auto &attr : payloadOp->getAttrs()) {
1162 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1163 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1164 attrToElide = attr.getName().str();
1165 elidedAttrs.push_back(attrToElide);
1174 Block *mapper = getBody();
1190 [&](auto arg) { p.printRegionArgument(arg); });
1199 auto *bodyBlock = getBody();
1200 auto blockArgs = bodyBlock->getArguments();
1203 if (getInputs().size() != blockArgs.size())
1204 return emitOpError() << "expects number of operands to match the arity of "
1206 << getInputs().size() << " and " << blockArgs.size();
1209 for (const auto &[bbArgType, inputArg] :
1210 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1211 auto inputElemType =
1212 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1213 if (bbArgType != inputElemType) {
1214 return emitOpError() << "expected element type of input " << inputElemType
1215 << " to match bbArg type " << bbArgType;
1220 auto outputShape = getInit().getType().getShape();
1222 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1223 if (inputElemShape != outputShape) {
1224 return emitOpError() << "expected shape of input (" << inputElemShape
1225 << ") to match shape of output (" << outputShape
1234 int64_t rank = getInit().getType().getRank();
1238 ArrayAttr MapOp::getIndexingMaps() {
1239 Builder builder(getContext());
1240 int64_t rank = getInit().getType().getRank();
1241 int64_t numIndexingMaps = getOperands().size();
1246 void MapOp::getEffects(
1250 getDpsInputOperands(), getDpsInitOperands());
1257 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1259 for (Value v : getRegionInputArgs())
1261 for (Value v : getRegionOutputArgs())
1262 setNameFn(v, "init");
1265 void ReduceOp::getAsmResultNames(
1267 if (!getResults().empty())
1268 setNameFn(getResults().front(), "reduced");
1271 void ReduceOp::build(
1276 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1280 for (Value init : inits) {
1282 if (llvm::isa<RankedTensorType>(initType))
1288 inputs, inits, bodyBuild);
1293 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1295 utils::IteratorType::parallel);
1296 for (int64_t reductionDim : getDimensions())
1297 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1298 return iteratorTypes;
1301 ArrayAttr ReduceOp::getIndexingMaps() {
1303 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1310 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1311 affineMaps.push_back(resultMap);
1315 void ReduceOp::getEffects(
1319 getDpsInputOperands(), getDpsInitOperands());
1324 StringRef attributeName) {
1333 std::optional<OperationName> payloadOpName;
1337 if (failed(operationName))
1341 payloadOpName = operationName.value();
1352 if (payloadOpName.has_value()) {
1372 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1376 Block *mapper = getBody();
1392 [&](auto arg) { p.printRegionArgument(arg); });
1403 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1404 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1405 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1406 return emitOpError() << "expects all inputs to have the same shapes. "
1407 "Shape at input-index "
1409 << " is not equal to the shape at input-index 0.";
1412 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1413 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1414 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1415 return emitOpError() << "expects all outputs to have the same shapes. "
1416 "Shape at output-index "
1418 << " is not equal to the shape at output-index 0.";
1421 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1422 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1425 for (int64_t dimension : dimensionsRef) {
1426 if (dimension < 0 || dimension >= inputType.getRank()) {
1427 return emitOpError()
1428 << "dimensions for reduction should be in the range [0, "
1429 << inputType.getRank() - 1 << "].";
1431 dimensionsToReduce.insert(dimension);
1434 auto inputDims = inputType.getShape();
1435 auto initDims = initType.getShape();
1440 if (!dimensionsToReduce.count(en.index()))
1441 reducedInputDims.push_back(en.value());
1444 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1445 return emitOpError() << "number of dimensions after reduction "
1446 << reducedInputDims.size()
1447 << " doesn't match the init rank "
1448 << initType.getRank();
1451 if (reducedInputDims != initDims)
1452 return emitOpError() << "init dimensions [" << initDims
1453 << "] doesn't match input dimensions after reduction ["
1454 << reducedInputDims << "]";
1456 Block *block = getBody();
1458 return emitOpError()
1459 << "mismatching number of operands and block arguments";
1462 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1463 Type inputElementType =
1464 llvm::cast<ShapedType>(input.getType()).getElementType();
1465 if (inputElementType != bbArg.getType())
1466 return emitOpError()
1467 << "input element type " << inputElementType
1468 << " does not match corresponding block argument type "
1473 for (auto [output, bbArg] :
1474 llvm::zip(getDpsInitOperands(),
1476 auto outputElementType =
1477 llvm::cast<ShapedType>(output->get().getType()).getElementType();
1478 if (outputElementType != bbArg.getType())
1479 return emitOpError()
1480 << "output element type " << outputElementType
1481 << " does not match corresponding block argument type "
1496 b.create<linalg::YieldOp>(loc, args[0]);
1511 if (llvm::isa<RankedTensorType>(initType))
1540 void TransposeOp::getAsmResultNames(
1542 if (!getResults().empty())
1543 setNameFn(getResults().front(), "transposed");
1557 return emitOpError("permutation is not valid");
1559 auto inputType = getInput().getType();
1560 auto initType = getInit().getType();
1562 int64_t rank = inputType.getRank();
1564 if (rank != initType.getRank())
1565 return emitOpError() << "input rank " << rank
1566 << " does not match init rank " << initType.getRank();
1568 if (rank != static_cast<int64_t>(permutationRef.size()))
1569 return emitOpError() << "size of permutation " << permutationRef.size()
1570 << " does not match the argument rank " << rank;
1572 auto inputDims = inputType.getShape();
1573 auto initDims = initType.getShape();
1575 for (int64_t i = 0; i < rank; ++i) {
1576 int64_t inputDim = inputDims[permutationRef[i]];
1577 int64_t initDim = initDims[i];
1579 if (inputDim != initDim) {
1580 return emitOpError() << "dim(result, " << i << ") = " << initDim
1581 << " doesn't match dim(input, permutation[" << i
1582 << "]) = " << inputDim;
1590 int64_t rank = getInit().getType().getRank();
1594 ArrayAttr TransposeOp::getIndexingMaps() {
1595 Builder builder(getContext());
1596 int64_t rank = getInit().getType().getRank();
1600 llvm::to_vector_of<unsigned>(getPermutation()), getContext())});
1603 void TransposeOp::getEffects(
1607 getDpsInputOperands(), getDpsInitOperands());
1625 if (llvm::isa<RankedTensorType>(initType))
1654 void BroadcastOp::getAsmResultNames(
1656 if (!getResults().empty())
1657 setNameFn(getResults().front(), "broadcasted");
1670 auto inputType = getInput().getType();
1671 auto initType = getInit().getType();
1673 int64_t inputRank = inputType.getRank();
1674 int64_t initRank = initType.getRank();
1676 auto inputShape = inputType.getShape();
1677 auto initShape = initType.getShape();
1679 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
1680 return emitOpError() << "input rank plus added dimensions does not "
1681 "match init rank. input rank: "
1683 << ", dimensions size: " << dimensionsRef.size()
1684 << ", init rank: " << initRank;
1687 if (dim < 0 || dim >= initRank)
1688 return emitOpError() << "dimension " << idx
1689 << " is out of range. expected range: [0, "
1690 << initRank - 1 << "], got: " << dim;
1695 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
1696 if (!llvm::is_contained(dimensionsRef, dim))
1697 dimMap.push_back(dim);
1700 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
1703 if (inputShape[inputDimIdx] != initShape[initDimIdx])
1704 return emitOpError() << "input dim " << inputDimIdx
1705 << " should match init dim " << initDimIdx
1706 << ". input: " << inputShape[inputDimIdx]
1707 << ", init: " << initShape[initDimIdx];
1714 int64_t rank = getInit().getType().getRank();
1718 ArrayAttr BroadcastOp::getIndexingMaps() {
1719 Builder builder(getContext());
1720 int64_t rank = getInit().getType().getRank();
1726 void BroadcastOp::getEffects(
1730 getDpsInputOperands(), getDpsInitOperands());
1738 if (getNumOperands() > 0)
1739 p << ' ' << getOperands();
1741 if (getNumOperands() > 0)
1742 p << " : " << getOperandTypes();
1759 return op.emitOpError("expected number of yield values (")
1760 << linalgOp.getNumDpsInits()
1761 << ") to match the number of operands of the enclosing "
1766 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
1768 if (opOperand.get().getType() != elementType)
1770 << (opOperand.getOperandNumber() + 1) << " ("
1771 << opOperand.get().getType() << ") doesn't match "
1772 << "the element type of the enclosing linalg.generic op ("
1773 << elementType << ")";
1779 auto *parentOp = (*this)->getParentOp();
1780 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1781 return emitOpError("expected single non-empty parent region");
1783 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1786 return emitOpError("expected parent op with LinalgOp interface");
1794 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
1796 return emitOpError("expected parent op with LinalgOp interface");
1797 if (linalgOp.getNumLoops() <= getDim())
1798 return emitOpError("expected dim (")
1799 << getDim() << ") to be lower than the number of loops ("
1800 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
1806 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
1808 #define GET_OP_CLASSES
1809 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1811 #define GET_OP_CLASSES
1812 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1829 for (unsigned i = 0; i < num; ++i)
1836 auto rangeA = llvm::make_range(a.begin(), a.end());
1837 auto rangeB = llvm::make_range(b.begin(), b.end());
1838 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
1839 return llvm::to_vector<4>(concatRanges);
1843 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
1845 for (auto size : memref.getShape())
1852 if (auto as = memref.getMemorySpace()) {
1853 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
1854 ss << "as" << attr.getInt();
1860 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
1863 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
1875 assert(isa<LinalgOp>(op));
1877 std::string fun = "";
1879 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
1880 fun = stringifyEnum(ufa.getValue()).str() + "_";
1881 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
1882 fun = stringifyEnum(bfa.getValue()).str() + "_";
1886 std::replace(name.begin(), name.end(), '.', '_');
1887 llvm::raw_string_ostream ss(name);
1891 return std::string();
1894 std::string res = ss.str();
1913 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
1916 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
1927 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
1935 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
1942 if (castOp->getBlock() != linalgOp->getBlock())
1949 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
1952 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1958 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
1960 rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
1963 outputOperands[resultNumber] = newOperand;
1964 newOperands.append(outputOperands.begin(), outputOperands.end());
1967 linalgOp->result_type_end());
1968 resultTypes[resultNumber] = resultType;
1969 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
1976 results[resultNumber] = castBack;
1988 if (linalgOp.isScalar(&opOperand))
1990 Value src = opOperand.get();
1991 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
1992 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2000 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2001 Value castSource = castOp.getSource();
2002 auto castSourceType =
2003 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2004 if (castSourceType && castSourceType.hasStaticShape())
2005 sourceShape = castSourceType.getShape();
2011 for (unsigned i = 0; i < sourceShape.size(); i++) {
2012 if (sourceType.isDynamicDim(i))
2014 if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
2015 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2025 static void createNewOperandWithStaticSizes(
2029 bool &changeNeeded) {
2031 newOperands.push_back(src);
2032 if (linalgOp.isScalar(opOperand))
2034 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2035 Type resultType = sourceType;
2036 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2037 resultTypes.push_back(resultType);
2041 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2045 bool newOperandNeeded = false;
2046 for (unsigned i = 0; i < sourceShape.size(); i++) {
2047 int64_t dimShape = sourceShape[i];
2049 if (!affineExprToSize.contains(dimExpr) ||
2050 !sourceType.isDynamicDim(i)) {
2051 newShape.push_back(dimShape);
2057 newShape.push_back(affineExprToSize[dimExpr]);
2058 newOperandNeeded = true;
2061 if (newOperandNeeded) {
2062 changeNeeded = true;
2065 Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2067 newOperands[index] = newOperand;
2069 if (linalgOp.isDpsInit(opOperand))
2070 resultTypes.push_back(resultType);
2081 if (!linalgOp.hasTensorSemantics())
2085 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2086 return !map.isProjectedPermutation();
2096 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2103 bool changeNeeded = false;
2104 newOperands.reserve(linalgOp->getNumOperands());
2105 resultTypes.reserve(linalgOp.getNumDpsInits());
2108 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2109 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2110 affineExprToSize, linalgOp, newOperands,
2111 resultTypes, changeNeeded);
2120 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2123 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2124 Value newResult = std::get<1>(it);
2125 Value oldResult = std::get<0>(it);
2128 replacements.push_back(
2129 (newType != oldType)
2130 ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2133 rewriter.replaceOp(linalgOp, replacements);
2147 void LinalgDialect::getCanonicalizationPatterns(
2149 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
2150 InferStaticShapeOfOperands>(getContext());
2156 return arith::ConstantOp::materialize(builder, value, type, loc);
Static helpers referenced in this listing:
static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
  A utility function used to materialize a constant for a given attribute and type.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs, RegionBuilderFn regionBuilder)
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs, RegionBuilderFn regionBuilder)
  Fills the region of a structured operation using the provided regionBuilder.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef<int64_t> attributeValue)
static Operation *findPayloadOp(Block *body, bool initFirst = false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional<TypeRange> resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef<NamedAttribute> attributes, RegionBuilderFn regionBuilder)
  Creates a structured operation given inputs, outputs, and attributes (a usage sketch follows this list).
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn = nullptr)
static void getGenericEffectsImpl(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects, ValueRange results, const OpOperandVector &inputOperands, const OpOperandVector &outputOperands)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes)
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef<Value> operands, bool initFirst = false)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl<Type> &inputTypes, SmallVectorImpl<Type> &outputTypes, bool addOperandSegmentSizes = true)
  Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
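The builder and region helpers summarized above (buildStructuredOp, buildGenericRegion, fillStructuredOpRegion, and the GenericOp::build overloads in the listing) all follow the same pattern: the caller supplies inputs, outputs, indexing maps, iterator types, and a body-building callback that populates the payload block and terminates it with linalg.yield. The sketch below shows that pattern from the client side. It is not part of LinalgOps.cpp: it is a minimal sketch assuming the upstream MLIR C++ API at roughly the revision of this listing (LLVM 16/17), float tensor operands of matching shape, and an already-positioned OpBuilder; the helper name buildElementwiseAdd is invented for illustration.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper (not from LinalgOps.cpp): emits an element-wise f32
// addition of `lhs` and `rhs` into `init` as a linalg.generic.
static Value buildElementwiseAdd(OpBuilder &b, Location loc, Value lhs,
                                 Value rhs, Value init) {
  int64_t rank = llvm::cast<RankedTensorType>(init.getType()).getRank();
  // All three operands use the identity indexing map; every loop is parallel.
  SmallVector<AffineMap> maps(3, b.getMultiDimIdentityMap(rank));
  SmallVector<utils::IteratorType> iterators(rank,
                                             utils::IteratorType::parallel);
  auto generic = b.create<linalg::GenericOp>(
      loc, /*resultTensorTypes=*/TypeRange{init.getType()},
      /*inputs=*/ValueRange{lhs, rhs}, /*outputs=*/ValueRange{init}, maps,
      iterators,
      // This lambda plays the role of bodyBuild/regionBuilder above: it fills
      // the payload block and terminates it with linalg.yield.
      [](OpBuilder &nested, Location nestedLoc, ValueRange args) {
        Value sum = nested.create<arith::AddFOp>(nestedLoc, args[0], args[1]);
        nested.create<linalg::YieldOp>(nestedLoc, sum);
      });
  return generic->getResult(0);
}

Printed back, such an op round-trips through the " ins(...) outs(...)" textual form produced by printCommonStructuredOpParts and consumed by parseCommonStructuredOpParts shown earlier in the listing.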