//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return tensor::DimOp::create(builder, loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return memref::DimOp::create(builder, loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of
/// the `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
                                              strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return memref::SubViewOp::create(b, loc, source, offsets, sizes,
                                         strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc,
                                       Value source, int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(
    ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
    function_ref<InFlightDiagnostic()>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   function_ref<InFlightDiagnostic()> emitError,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs, emitError);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}
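
// For instance (illustrative, types abbreviated), for a named op with two
// tensor<?x?xf32> inputs and one tensor<?x?xf32> output, the block created
// here has the elemental-type signature:
//
//   ^bb0(%in: f32, %in_0: f32, %out: f32):
//     <ops emitted by `regionBuilder`>
//     linalg.yield %0 : f32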

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), /*emitError=*/{},
                         regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
                                     std::optional<TypeRange> resultTensorTypes,
                                     ValueRange inputs, ValueRange outputs,
                                     ArrayRef<NamedAttribute> attributes,
                                     RegionBuilderFn regionBuilder,
                                     ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

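// The parser above and this printer agree on the shared textual form, e.g.
// (illustrative):
//
//   ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<8x16xf32>)
//   outs(%arg2 : tensor<4x16xf32>)
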
//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder, SMLoc loc) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  ParseResult result = success();
  fillStructuredOpRegion(
      opBuilder, region, inputTypes, outputTypes, attrs,
      [&]() {
        result = failure();
        return parser.emitError(loc);
      },
      regionBuilder);
  return result;
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  SMLoc loc = parser.getCurrentLocation();
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder, loc))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

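// Taken together, a named structured op round-trips through a form like the
// following (illustrative):
//
//   %0 = linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
//                      outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
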
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. The helpers build the unary, binary, and type conversion functions
// defined by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the
// code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr. This
// can be extended later if it becomes possible to fail construction of the
// region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
                     function_ref<InFlightDiagnostic()> emitError = {}) {
    if (!isFloatingPoint(arg)) {
      if (emitError) {
        emitError() << "unsupported non numeric type";
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return math::ExpOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::log:
      return math::LogOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::abs:
      return math::AbsFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::ceil:
      return math::CeilOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::floor:
      return math::FloorOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::negf:
      return arith::NegFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = arith::ConstantOp::create(builder, arg.getLoc(),
                                           ::cast<TypedAttr>(oneAttr));
      return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return math::RoundOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return math::SqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return math::RsqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::square:
      return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return math::TanhOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::erf:
      return math::ErfOp::create(builder, arg.getLoc(), arg);
    }
    if (emitError) {
      emitError() << "unsupported unary function";
      return nullptr;
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  // If emitError is provided, an error will be emitted if the operation is not
  // supported and a nullptr will be returned, otherwise an assertion will be
  // raised.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
                      function_ref<InFlightDiagnostic()> emitError = {}) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger) {
      if (emitError) {
        emitError()
            << "Cannot build binary Linalg operation: expects allComplex, "
               "allFloatingPoint, or allInteger, got "
            << arg0.getType() << " and " << arg1.getType();
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: sub with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: sub with bools");
      }
      return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: div with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: div with bools");
      }
      return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }
      return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
    }
    if (emitError) {
      emitError() << "unsupported binary function";
      return nullptr;
    }
    llvm_unreachable("unsupported binary function");
  }

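  // As an illustration of the type-based dispatch above (the names `helper`,
  // `lhs`, and `rhs` are hypothetical): for f32 operands, a call such as
  //
  //   Value sum = helper.buildBinaryFn(BinaryFn::add, lhs, rhs);
  //
  // materializes arith.addf, while the same call on i1 operands materializes
  // arith.ori, and on wider integers arith.addi.
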
  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
                       function_ref<InFlightDiagnostic()> emitError = {}) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
    }
    if (emitError) {
      emitError() << "unsupported ternary function";
      return nullptr;
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
                    function_ref<InFlightDiagnostic()> emitError = {}) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    if (emitError) {
      emitError() << "unsupported type conversion function";
      return nullptr;
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    YieldOp::create(builder, loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return arith::ConstantOp::create(builder, loc,
                                     ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return IndexOp::create(builder, builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    if (isa<UnknownLoc>(loc)) {
      if (operand.getDefiningOp())
        loc = operand.getDefiningOp()->getLoc();
      else if (operand.getParentBlock() &&
               operand.getParentBlock()->getParentOp())
        loc = operand.getParentBlock()->getParentOp()->getLoc();
    }
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};
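
// E.g. (illustrative), a self copy on buffers such as
//
//   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
//
// is erased outright; in the tensor case the result is replaced by the input.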

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};
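
// E.g. (illustrative), reshaping a filled tensor
//
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4x4xf32>)
//             -> tensor<4x4xf32>
//   %r = tensor.collapse_shape %fill [[0, 1]]
//             : tensor<4x4xf32> into tensor<16xf32>
//
// becomes a fill of the reshaped init:
//
//   %init2 = tensor.collapse_shape %init [[0, 1]]
//             : tensor<4x4xf32> into tensor<16xf32>
//   %r = linalg.fill ins(%cst : f32) outs(%init2 : tensor<16xf32>)
//             -> tensor<16xf32>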

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
                                padOp.getResultType().getElementType());
    Value replacement =
        FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
                       ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
                                           padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
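
// E.g. (illustrative), padding a filled tensor with the fill value
//
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4x4xf32>)
//             -> tensor<4x4xf32>
//   %pad = tensor.pad %fill low[1, 1] high[1, 1] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<4x4xf32> to tensor<6x6xf32>
//
// folds to a single linalg.fill of an appropriately sized tensor.empty.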

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this,
      // we cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusive for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old offsets
    // plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
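
// E.g. (illustrative), extracting from a filled tensor
//
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4x8xf32>)
//             -> tensor<4x8xf32>
//   %e = tensor.extract %fill[%i, %j] : tensor<4x8xf32>
//
// folds %e directly to %cst.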

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
                                packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
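
// E.g. (illustrative), copying out of a filled tensor
//
//   %fill = linalg.fill ins(%cst : f32) outs(%a : tensor<8xf32>) -> tensor<8xf32>
//   %res = linalg.copy ins(%fill : tensor<8xf32>) outs(%b : tensor<8xf32>)
//
// is rewritten to fill %b directly; symmetrically, a copy whose destination
// is produced by a fill copies into the fill's own init instead.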

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
                                                concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
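
// E.g. (illustrative), concatenating two fills of the same value
//
//   %c = tensor.concat dim(0) %fill0, %fill1
//          : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
//
// becomes a tensor.concat of the two fill inits followed by a single
// linalg.fill of the shared value.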

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

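// The printed generic form produced above looks like (illustrative):
//
//   %0 = linalg.generic {indexing_maps = [#map, #map],
//                        iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<8xf32>
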
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

namespace {

/// Remove linalg operations that are just copying the values from inputs to
/// results. In the memref case, the operation must be copying to and from the
/// same value. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }

      rewriter.eraseOp(linalgOp);
      return success();
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = sparse_tensor::ConvertOp::create(
              rewriter, linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
                                               resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
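
// E.g. (illustrative), an identity generic such as
//
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<8xf32>
//
// is replaced by %in, with a tensor.cast (or sparse_tensor.convert) inserted
// when the result type differs from the input type.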

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{init}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false, bool mapInit = true) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If the initFirst flag is enabled, we consider init as the first position
  // of the payload operands.
  if (initFirst) {
    if (mapInit)
      payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end() - int(!mapInit)};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  YieldOp::create(b, result.location, payloadOp->getResults());
}
1536
1537ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1538 std::optional<OperationName> payloadOpName;
1539 NamedAttrList payloadOpAttrs;
1540 if (succeeded(parser.parseOptionalLBrace())) {
1541 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1542 if (failed(operationName))
1543 return failure();
1544 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1545 return failure();
1546 payloadOpName = operationName.value();
1547 if (parser.parseRBrace())
1548 return failure();
1549 }
1550
1551 if (parseDstStyleOp(parser, result))
1552 return failure();
1553
1554 if (payloadOpName.has_value()) {
1555 if (!result.operands.empty())
1556 addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1557 payloadOpAttrs, ArrayRef(result.operands),
1558 /*initFirst=*/false, /*mapInit=*/false);
1559 else
1560 result.addRegion();
1561 } else {
1562 SmallVector<OpAsmParser::Argument> regionArgs;
1563 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1564 /*allowType=*/true, /*allowAttrs=*/true)) {
1565 return failure();
1566 }
1567 Region *body = result.addRegion();
1568 if (parser.parseRegion(*body, regionArgs))
1569 return failure();
1570 }
1571 return success();
1572}
1573
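/// A sketch of the short form this helper detects (exact assembly may vary):
/// a linalg.map body such as
///   (%in0: f32, %in1: f32, %init: f32) {
///     %0 = arith.addf %in0, %in1 : f32
///     linalg.yield %0 : f32
///   }
/// satisfies the four conditions below (with `mapInit` false) and prints as
/// `linalg.map { arith.addf } ins(...) outs(...)`.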
1574static bool canUseShortForm(Block *body, bool initFirst = false,
1575 bool mapInit = true) {
1576 // `initFirst == true` implies that we want to map the init arg.
1577 if (initFirst && !mapInit)
1578 return false;
1579 // Check if the body can be printed in short form. The following 4 conditions
1580 // must be satisfied:
1581
1582 // 1) The body must contain exactly 2 operations: the payload op and a yield.
1583 if (body->getOperations().size() != 2)
1584 return false;
1585 Operation &payload = body->getOperations().front();
1586
1587 // 2) The payload op must have one operand for each mapped block argument
1588 // (all of them, minus the init argument when `mapInit` is false).
1589 if (payload.getNumOperands() == 0 ||
1590 payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
1591 return false;
1592
1593 // 3) If `initFirst` is true (e.g., for reduction ops), the init block
1594 // argument must be the first operand of the payload op; otherwise, the
1595 // operands must match the block arguments in order.
1596 if (initFirst) {
1597 // check init
1598 if (payload.getOperands().back() != body->getArgument(0))
1599 return false;
1600 // check rest
1601 for (const auto &[operand, bbArg] :
1602 llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1603 if (bbArg != operand)
1604 return false;
1605 }
1606 } else {
1607 for (const auto &[operand, bbArg] :
1608 llvm::zip(payload.getOperands(),
1609 body->getArguments().drop_back(int(!mapInit)))) {
1610 if (bbArg != operand)
1611 return false;
1612 }
1613 }
1614
1615 // 4) The `yield` operand must be the result of the payload op.
1616 auto yieldOp = cast<YieldOp>(body->getTerminator());
1617 return yieldOp.getNumOperands() == 1 &&
1618 yieldOp.getOperand(0).getDefiningOp() &&
1619 yieldOp.getOperand(0).getDefiningOp() == &payload;
1620}
1621
1622static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1623 SmallVector<StringRef> elidedAttrs;
1624 std::string attrToElide;
1625 p << " { " << payloadOp->getName().getStringRef();
1626 for (const auto &attr : payloadOp->getAttrs()) {
1627 auto fastAttr =
1628 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1629 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1630 attrToElide = attr.getName().str();
1631 elidedAttrs.push_back(attrToElide);
1632 break;
1633 }
1634 }
1635 p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1636 p << " }";
1637}
1638
1639void MapOp::print(OpAsmPrinter &p) {
1640 Block *mapper = getBody();
1641 bool useShortForm =
1642 canUseShortForm(mapper, /*initFirst=*/false, /*mapInit=*/false);
1643 if (useShortForm) {
1644 printShortForm(p, &mapper->getOperations().front());
1645 }
1646
1647 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1648 p.printOptionalAttrDict((*this)->getAttrs());
1649
1650 if (!useShortForm) {
1651 // Print region if the payload op was not detected.
1652 p.increaseIndent();
1653 p.printNewline();
1654 p << "(";
1655 llvm::interleaveComma(mapper->getArguments(), p,
1656 [&](auto arg) { p.printRegionArgument(arg); });
1657 p << ") ";
1658
1659 p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1660 p.decreaseIndent();
1661 }
1662}
1663
1664LogicalResult MapOp::verify() {
1665 auto *bodyBlock = getBody();
1666 auto blockArgs = bodyBlock->getArguments();
1667
1668 // Check that the number of `inputs` + `init` matches the arity of the
1669 // `mapper` region.
1670 if (getInputs().size() + 1 != blockArgs.size())
1671 return emitOpError() << "expects number of operands to match the arity of "
1672 "mapper, but got: "
1673 << getInputs().size() + 1 << " and "
1674 << blockArgs.size();
1675
1676 // The parameters of the mapper must all match the element types of the inputs.
1677 for (const auto &[bbArgType, inputArg] :
1678 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1679 auto inputElemType =
1680 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1681 if (bbArgType != inputElemType) {
1682 return emitOpError() << "expected element type of input " << inputElemType
1683 << " to match bbArg type " << bbArgType;
1684 }
1685 }
1686
1687 // The shape of each input must match the shape of the output.
1688 auto outputShape = getInit().getType().getShape();
1689 for (Type inputArgType : TypeRange{getInputs()}) {
1690 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1691 if (inputElemShape != outputShape) {
1692 return emitOpError() << "expected shape of input (" << inputElemShape
1693 << ") to match shape of output (" << outputShape
1694 << ")";
1695 }
1696 }
1697
1698 return success();
1699}
1700
1701SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1702 int64_t rank = getInit().getType().getRank();
1703 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1704}
1705
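/// All indexing maps of linalg.map are identities; e.g., for a rank-2 init
/// with two inputs this returns three copies of (d0, d1) -> (d0, d1).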
1706ArrayAttr MapOp::getIndexingMaps() {
1707 Builder builder(getContext());
1708 int64_t rank = getInit().getType().getRank();
1709 int64_t numIndexingMaps = getOperands().size();
1710 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1711 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1712}
1713
1714void MapOp::getEffects(
1715 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1716 &effects) {
1717 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1718}
1719
1720Speculation::Speculatability MapOp::getSpeculatability() {
1721 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1722}
1723
1724//===----------------------------------------------------------------------===//
1725// ReduceOp
1726//===----------------------------------------------------------------------===//
1727
1728void ReduceOp::getAsmBlockArgumentNames(Region &region,
1729 OpAsmSetValueNameFn setNameFn) {
1730 for (Value v : getRegionInputArgs())
1731 setNameFn(v, "in");
1732 for (Value v : getRegionOutputArgs())
1733 setNameFn(v, "init");
1734}
1735
1736void ReduceOp::getAsmResultNames(
1737 function_ref<void(Value, StringRef)> setNameFn) {
1738 if (!getResults().empty())
1739 setNameFn(getResults().front(), "reduced");
1740}
1741
1742void ReduceOp::build(
1743 OpBuilder &builder, OperationState &result, ValueRange inputs,
1744 ValueRange inits, ArrayRef<int64_t> dimensions,
1745 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1746 ArrayRef<NamedAttribute> attributes) {
1747 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1748 result.addAttributes(attributes);
1749
1750 // Add output types for `RankedTensorType` output arguments.
1751 for (Value init : inits) {
1752 Type initType = init.getType();
1753 if (llvm::isa<RankedTensorType>(initType))
1754 result.addTypes(initType);
1755 }
1756
1757 if (bodyBuild)
1758 buildGenericRegion(builder, result.location, *result.regions.front(),
1759 inputs, inits, bodyBuild);
1760}
1761
1762SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1763 int64_t inputRank =
1764 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1765 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1766 utils::IteratorType::parallel);
1767 for (int64_t reductionDim : getDimensions())
1768 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1769 return iteratorTypes;
1770}
1771
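/// E.g., for a rank-3 input reduced along dimensions = [1], every input keeps
/// the identity map (d0, d1, d2) -> (d0, d1, d2) and every init gets the map
/// with the reduced dimension dropped: (d0, d1, d2) -> (d0, d2).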
1772ArrayAttr ReduceOp::getIndexingMaps() {
1773 int64_t inputRank =
1774 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1775 SmallVector<AffineMap> affineMaps(
1776 getNumDpsInputs(),
1777 AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1778 AffineMap resultMap =
1779 AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1780 .dropResults(getDimensions());
1781 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1782 affineMaps.push_back(resultMap);
1783 return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1784}
1785
1786void ReduceOp::getEffects(
1787 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1788 &effects) {
1789 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1790}
1791
1792Speculation::Speculatability ReduceOp::getSpeculatability() {
1793 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1794}
1795
1796static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1797 NamedAttrList &attributes,
1798 StringRef attributeName) {
1799 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1800 return failure();
1801
1802 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1803 return success();
1804}
1805
1806ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1807 std::optional<OperationName> payloadOpName;
1808 NamedAttrList payloadOpAttrs;
1809 if (succeeded(parser.parseOptionalLBrace())) {
1810 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1811 if (failed(operationName))
1812 return failure();
1813 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1814 return failure();
1815 payloadOpName = operationName.value();
1816 if (parser.parseRBrace())
1817 return failure();
1818 }
1819
1820 if (parseDstStyleOp(
1821 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1822 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1823 }))
1824 return failure();
1825
1826 if (payloadOpName.has_value()) {
1827 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1828 ArrayRef(result.operands), /*initFirst=*/true);
1829 } else {
1830 SmallVector<OpAsmParser::Argument> regionArgs;
1831 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1832 /*allowType=*/true, /*allowAttrs=*/true)) {
1833 return failure();
1834 }
1835
1836 Region *body = result.addRegion();
1837 if (parser.parseRegion(*body, regionArgs))
1838 return failure();
1839 }
1840
1841 return success();
1842}
1843
1844static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1845 ArrayRef<int64_t> attributeValue) {
1846 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1847}
1848
1849void ReduceOp::print(OpAsmPrinter &p) {
1850 Block *mapper = getBody();
1851 bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
1852 if (useShortForm) {
1853 printShortForm(p, &mapper->getOperations().front());
1854 }
1855
1856 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1857 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1858 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1859 if (!useShortForm) {
1860 // Print region if the payload op was not detected.
1861 p.increaseIndent();
1862 p.printNewline();
1863 p << "(";
1864 llvm::interleaveComma(mapper->getArguments(), p,
1865 [&](auto arg) { p.printRegionArgument(arg); });
1866 p << ") ";
1867
1868 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1869 p.decreaseIndent();
1870 }
1871}
1872
1873LogicalResult ReduceOp::verify() {
1874 ArrayRef<int64_t> dimensionsRef = getDimensions();
1875
1876 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1877 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1878 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1879 return emitOpError() << "expects all inputs to have the same shapes. "
1880 "Shape at input-index "
1881 << i
1882 << " is not equal to the shape at input-index 0.";
1883 }
1884 }
1885 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1886 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1887 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1888 return emitOpError() << "expects all outputs to have the same shapes. "
1889 "Shape at output-index "
1890 << i
1891 << " is not equal to the shape at output-index 0.";
1892 }
1893 }
1894 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1895 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1896
1897 DenseSet<int64_t> dimensionsToReduce;
1898 for (int64_t dimension : dimensionsRef) {
1899 if (dimension < 0 || dimension >= inputType.getRank()) {
1900 return emitOpError()
1901 << "dimensions for reduction should be in the range [0, "
1902 << inputType.getRank() - 1 << "].";
1903 }
1904 dimensionsToReduce.insert(dimension);
1905 }
1906
1907 auto inputDims = inputType.getShape();
1908 auto initDims = initType.getShape();
1909
1910 // Input dimensions that will be left after the reduction.
1911 SmallVector<int64_t> reducedInputDims;
1912 for (const auto &en : llvm::enumerate(inputDims)) {
1913 if (!dimensionsToReduce.count(en.index()))
1914 reducedInputDims.push_back(en.value());
1915 }
1916
1917 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1918 return emitOpError() << "number of dimensions after reduction "
1919 << reducedInputDims.size()
1920 << " doesn't match the init rank "
1921 << initType.getRank();
1922 }
1923
1924 if (reducedInputDims != initDims)
1925 return emitOpError() << "init dimensions [" << initDims
1926 << "] doesn't match input dimensions after reduction ["
1927 << reducedInputDims << "]";
1928
1929 Block *block = getBody();
1930 if (block->getNumArguments() != this->getNumOperands())
1931 return emitOpError()
1932 << "mismatching number of operands and block arguments";
1933
1934 // Check that the first block arguments match the element type of the inputs.
1935 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1936 Type inputElementType =
1937 llvm::cast<ShapedType>(input.getType()).getElementType();
1938 if (inputElementType != bbArg.getType())
1939 return emitOpError()
1940 << "input element type " << inputElementType
1941 << " does not match corresponding block argument type "
1942 << bbArg.getType();
1943 }
1944
1945 // Check that the last block arguments match the element type of the outputs.
1946 for (auto [output, bbArg] : llvm::zip(
1947 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1948 auto outputElementType =
1949 llvm::cast<ShapedType>(output.getType()).getElementType();
1950 if (outputElementType != bbArg.getType())
1951 return emitOpError()
1952 << "output element type " << outputElementType
1953 << " does not match corresponding block argument type "
1954 << bbArg.getType();
1955 }
1956 return success();
1957}
1958
1959//===----------------------------------------------------------------------===//
1960// TransposeOp
1961//===----------------------------------------------------------------------===//
1962
1963static void buildIdentityRegion(OpBuilder &builder, Location loc,
1964 Region &region, ValueRange inputs,
1965 ValueRange outputs) {
1966 buildGenericRegion(builder, loc, region, inputs, outputs,
1967 [](OpBuilder &b, Location loc, ValueRange args) {
1968 if (!args.empty())
1969 linalg::YieldOp::create(b, loc, args[0]);
1970 });
1971}
1972
1973void TransposeOp::build(::mlir::OpBuilder &builder,
1974 ::mlir::OperationState &result, Value input, Value init,
1975 DenseI64ArrayAttr permutation,
1976 ArrayRef<NamedAttribute> attributes) {
1977 result.addOperands(input);
1978 result.addOperands(init);
1979 result.addAttribute(getPermutationAttrName(result.name), permutation);
1980 result.addAttributes(attributes);
1981
1982 // Add output types for `RankedTensorType` output arguments.
1983 Type initType = init.getType();
1984 if (llvm::isa<RankedTensorType>(initType))
1985 result.addTypes(initType);
1986
1987 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1988 init);
1989}
1990
1991void TransposeOp::build(::mlir::OpBuilder &builder,
1992 ::mlir::OperationState &result, Value input, Value init,
1993 ArrayRef<int64_t> permutation,
1994 ArrayRef<NamedAttribute> attributes) {
1995 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1996 attributes);
1997}
1998
1999ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2000 if (failed(parseDstStyleOp(
2001 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2002 return parseDenseI64ArrayAttr(parser, attributes, "permutation");
2003 })))
2004 return failure();
2005
2006 OpBuilder builder(parser.getContext());
2007 buildIdentityRegion(builder, result.location, *result.addRegion(),
2008 /*inputs=*/result.operands,
2009 /*outputs=*/{});
2010 return success();
2011}
2012
2013void TransposeOp::getAsmResultNames(
2014 function_ref<void(Value, StringRef)> setNameFn) {
2015 if (!getResults().empty())
2016 setNameFn(getResults().front(), "transposed");
2017}
2018
2019void TransposeOp::print(OpAsmPrinter &p) {
2020 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2021 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
2022 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2023}
2024
2025LogicalResult TransposeOp::verify() {
2026 ArrayRef<int64_t> permutationRef = getPermutation();
2027
2028 if (!isPermutationVector(permutationRef))
2029 return emitOpError("permutation is not valid");
2030
2031 auto inputType = getInput().getType();
2032 auto initType = getInit().getType();
2033
2034 int64_t rank = inputType.getRank();
2035
2036 if (rank != initType.getRank())
2037 return emitOpError() << "input rank " << rank
2038 << " does not match init rank " << initType.getRank();
2039
2040 if (rank != static_cast<int64_t>(permutationRef.size()))
2041 return emitOpError() << "size of permutation " << permutationRef.size()
2042 << " does not match the argument rank " << rank;
2043
2044 auto inputDims = inputType.getShape();
2045 auto initDims = initType.getShape();
2046
2047 for (int64_t i = 0; i < rank; ++i) {
2048 int64_t inputDim = inputDims[permutationRef[i]];
2049 int64_t initDim = initDims[i];
2050
2051 if (inputDim != initDim) {
2052 return emitOpError() << "dim(result, " << i << ") = " << initDim
2053 << " doesn't match dim(input, permutation[" << i
2054 << "]) = " << inputDim;
2055 }
2056 }
2057
2058 return success();
2059}
2060
2061SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2062 int64_t rank = getInit().getType().getRank();
2063 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2064}
2065
2066ArrayAttr TransposeOp::getIndexingMaps() {
2067 Builder builder(getContext());
2068 int64_t rank = getInit().getType().getRank();
2069 return builder.getAffineMapArrayAttr(
2070 {inversePermutation(AffineMap::getPermutationMap(
2071 llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
2072 builder.getMultiDimIdentityMap(rank)});
2073}
2074
2075void TransposeOp::getEffects(
2076 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2077 &effects) {
2078 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2079}
2080
2081Speculation::Speculatability TransposeOp::getSpeculatability() {
2082 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2083}
2084
2085LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2086 SmallVectorImpl<OpFoldResult> &result) {
2087 // Only tensor types are supported.
2088 if (!isa<TensorType>(getInput().getType()))
2089 return failure();
2090
2091 // Empty permutation (rank-0 operand): the transpose is a no-op.
2092 if (getPermutation().empty()) {
2093 result.push_back(getInput());
2094 return success();
2095 }
2096 // Identity permutation.
2097 if (isIdentityPermutation(getPermutation())) {
2098 result.push_back(getInput());
2099 return success();
2100 }
2101
2102 return failure();
2103}
2104
2105/// Fold transpose with transpose.
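/// E.g., transpose(permutation = [1, 0]) of transpose(permutation = [1, 0])
/// composes to the identity permutation [0, 1], since
/// foldedPerms[i] = defPerms[perms[i]].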
2106struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2107 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2108
2109 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2110 PatternRewriter &rewriter) const override {
2111 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2112 if (!defTransposeOp)
2113 return failure();
2114 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2115 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2116 SmallVector<int64_t> foldedPerms;
2117 foldedPerms.reserve(perms.size());
2118 for (int64_t perm : perms)
2119 foldedPerms.push_back(defPerms[perm]);
2120
2121 rewriter.replaceOpWithNewOp<TransposeOp>(
2122 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2123 foldedPerms);
2124 return success();
2125 }
2126};
2127
2128/// This pattern canonicalizes transpose by swapping the order of
2129/// broadcast and transpose:
2130/// transpose(broadcast(input)) -> broadcast(transpose(input))
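/// E.g. (a sketch), for %in : tensor<4xf32>, broadcast(dimensions = [0])
/// to tensor<2x4xf32> followed by transpose(permutation = [1, 0]) becomes
/// transpose(permutation = [0]) of %in (a no-op on the 1-D input) followed
/// by broadcast(dimensions = [1]) to tensor<4x2xf32>.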
2131struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2132 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2133
2134 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2135 PatternRewriter &rewriter) const override {
2136 Value input = transposeOp.getInput();
2137 BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2138 if (!input.hasOneUse() || !broadcastOp)
2139 return failure();
2140
2141 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2142 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2143
2144 // Get new perms and new dimensions.
2145 SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2146 SmallVector<int64_t> invertPerm = invertPermutationVector(resultPerms);
2147 SmallVector<int64_t> resultDimensions;
2148 unsigned dimensionSize = dimensions.size();
2149 for (unsigned i = 0; i < dimensionSize; ++i)
2150 resultDimensions.push_back(invertPerm[dimensions[i]]);
2151
2152 // Create transpose result.
2153 Value broadcastInput = broadcastOp.getInput();
2154 Location loc = transposeOp.getLoc();
2155 MLIRContext *ctx = transposeOp.getContext();
2156 SmallVector<OpFoldResult> dims;
2157 auto broadcastInputTy =
2158 mlir::cast<RankedTensorType>(broadcastInput.getType());
2159 unsigned inputRank = broadcastInputTy.getRank();
2160 for (unsigned i = 0; i < inputRank; ++i) {
2161 if (broadcastInputTy.isDynamicDim(i)) {
2162 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2163 ->getResult(0));
2164 } else {
2165 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2166 broadcastInputTy.getDimSize(i)));
2167 }
2168 }
2169 SmallVector<OpFoldResult> transposeResultShapes =
2170 applyPermutation(dims, resultPerms);
2171 Value transposeInit = tensor::EmptyOp::create(
2172 rewriter, transposeOp.getLoc(), transposeResultShapes,
2173 broadcastInputTy.getElementType());
2174
2175 // Create broadcast(transpose(input)).
2176 Value transposeResult =
2177 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2178 transposeInit, resultPerms)
2179 ->getResult(0);
2180 rewriter.replaceOpWithNewOp<BroadcastOp>(
2181 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2182 return success();
2183 }
2184};
2185
2186void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2187 MLIRContext *context) {
2188 results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2189}
2190
2191//===----------------------------------------------------------------------===//
2192// BroadcastOp
2193//===----------------------------------------------------------------------===//
2194
2195void BroadcastOp::build(::mlir::OpBuilder &builder,
2196 ::mlir::OperationState &result, Value input, Value init,
2197 DenseI64ArrayAttr dimensions,
2198 ArrayRef<NamedAttribute> attributes) {
2199 result.addOperands(input);
2200 result.addOperands(init);
2201 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2202 result.addAttributes(attributes);
2203
2204 // Add output types for `RankedTensorType` output arguments.
2205 Type initType = init.getType();
2206 if (llvm::isa<RankedTensorType>(initType))
2207 result.addTypes(initType);
2208
2209 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2210 init);
2211}
2212
2213void BroadcastOp::build(::mlir::OpBuilder &builder,
2214 ::mlir::OperationState &result, Value input, Value init,
2215 ArrayRef<int64_t> dimensions,
2216 ArrayRef<NamedAttribute> attributes) {
2217 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2218 attributes);
2219}
2220
2221ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2222 if (failed(parseDstStyleOp(
2223 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2224 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2225 })))
2226 return failure();
2227
2228 OpBuilder builder(parser.getContext());
2229 buildIdentityRegion(builder, result.location, *result.addRegion(),
2230 /*inputs=*/result.operands,
2231 /*outputs=*/{});
2232 return success();
2233}
2234
2235void BroadcastOp::getAsmResultNames(
2236 function_ref<void(Value, StringRef)> setNameFn) {
2237 if (!getResults().empty())
2238 setNameFn(getResults().front(), "broadcasted");
2239}
2240
2241void BroadcastOp::print(OpAsmPrinter &p) {
2242 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2243 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2244 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2245}
2246
2247LogicalResult BroadcastOp::verify() {
2248 ArrayRef<int64_t> dimensionsRef = getDimensions();
2249
2250 auto inputType = getInput().getType();
2251 auto initType = getInit().getType();
2252
2253 int64_t inputRank = inputType.getRank();
2254 int64_t initRank = initType.getRank();
2255
2256 auto inputShape = inputType.getShape();
2257 auto initShape = initType.getShape();
2258
2259 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2260 return emitOpError() << "input rank plus added dimensions does not "
2261 "match init rank. input rank: "
2262 << inputRank
2263 << ", dimensions size: " << dimensionsRef.size()
2264 << ", init rank: " << initRank;
2265
2266 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2267 if (dim < 0 || dim >= initRank)
2268 return emitOpError() << "dimension " << idx
2269 << " is out of range. expected range: [0, "
2270 << initRank - 1 << "], got: " << dim;
2271 }
2272
2273 // Mapping from input dims to init dims.
2274 SmallVector<int64_t> dimMap;
2275 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2276 if (!llvm::is_contained(dimensionsRef, dim))
2277 dimMap.push_back(dim);
2278 }
2279
2280 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2281 // This dimension is mapped from the input. Init and input dims should
2282 // match.
2283 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2284 return emitOpError() << "input dim " << inputDimIdx
2285 << " should match init dim " << initDimIdx
2286 << ". input: " << inputShape[inputDimIdx]
2287 << ", init: " << initShape[initDimIdx];
2288 }
2289
2290 return success();
2291}
2292
2293SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2294 int64_t rank = getInit().getType().getRank();
2295 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2296}
2297
2298ArrayAttr BroadcastOp::getIndexingMaps() {
2299 Builder builder(getContext());
2300 int64_t rank = getInit().getType().getRank();
2301 return builder.getAffineMapArrayAttr(
2302 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2303 builder.getMultiDimIdentityMap(rank)});
2304}
2305
2306void BroadcastOp::getEffects(
2307 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2308 &effects) {
2309 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2310}
2311
2312Speculation::Speculatability BroadcastOp::getSpeculatability() {
2313 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2314}
2315
2316/// Fold back-to-back broadcasts together.
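/// E.g. (a sketch), broadcast(dimensions = [1]) of a 1-D input followed by
/// broadcast(dimensions = [0]) folds to a single broadcast(dimensions =
/// [0, 2]) once the inner dimensions are remapped through the outer init.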
2317struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
2318 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2319
2320 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2321 PatternRewriter &rewriter) const override {
2322 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2323 if (!defBroadcastOp)
2324 return failure();
2325 ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
2326 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2327 SmallVector<int64_t> foldedDims(dimensions);
2328 Value init = broadcastOp.getInit();
2329 int64_t initRank = cast<ShapedType>(init.getType()).getRank();
2330 // Mapping from input dims to init dims.
2331 SmallVector<int64_t> dimMap;
2332 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2333 if (!llvm::is_contained(dimensions, dim))
2334 dimMap.push_back(dim);
2335 }
2336 for (auto dim : defDimensions)
2337 foldedDims.push_back(dimMap[dim]);
2338
2339 llvm::sort(foldedDims);
2340 rewriter.replaceOpWithNewOp<BroadcastOp>(
2341 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2342 return success();
2343 }
2344};
2345
2346void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2347 MLIRContext *context) {
2348 results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2349}
2350
2351//===----------------------------------------------------------------------===//
2352// YieldOp
2353//===----------------------------------------------------------------------===//
2354
2355void linalg::YieldOp::print(OpAsmPrinter &p) {
2356 if (getNumOperands() > 0)
2357 p << ' ' << getOperands();
2358 p.printOptionalAttrDict((*this)->getAttrs());
2359 if (getNumOperands() > 0)
2360 p << " : " << getOperandTypes();
2361}
2362
2363ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2364 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2365 SmallVector<Type, 2> types;
2366 SMLoc loc = parser.getCurrentLocation();
2367 return failure(parser.parseOperandList(opInfo) ||
2368 parser.parseOptionalAttrDict(result.attributes) ||
2369 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2370 parser.resolveOperands(opInfo, types, loc, result.operands));
2371}
2372
2373// Check that the number and types of the yield operands match the element
2374// types of the LinalgOp interface's shaped operands.
2375static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2376 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2377 return op.emitOpError("expected number of yield values (")
2378 << op.getNumOperands()
2379 << ") to match the number of inits / outs operands of the enclosing "
2380 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2381
2382 for (OpOperand &opOperand : op->getOpOperands()) {
2383 OpOperand *outputOperand =
2384 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2385 Type elementType = outputOperand->get().getType();
2386 if (isa<MemRefType, RankedTensorType>(elementType))
2387 elementType = getElementTypeOrSelf(outputOperand->get().getType());
2388 if (opOperand.get().getType() != elementType)
2389 return op.emitOpError("type of yield operand ")
2390 << (opOperand.getOperandNumber() + 1) << " ("
2391 << opOperand.get().getType() << ") doesn't match "
2392 << "the element type of the enclosing linalg.generic op ("
2393 << elementType << ")";
2394 }
2395 return success();
2396}
2397
2398LogicalResult linalg::YieldOp::verify() {
2399 auto *parentOp = (*this)->getParentOp();
2400 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2401 return emitOpError("expected single non-empty parent region");
2402
2403 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2404 return verifyYield(*this, linalgOp);
2405
2406 return emitOpError("expected parent op with LinalgOp interface");
2407}
2408
2409//===----------------------------------------------------------------------===//
2410// IndexOp
2411//===----------------------------------------------------------------------===//
2412
2413LogicalResult IndexOp::verify() {
2414 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2415 if (!linalgOp)
2416 return emitOpError("expected parent op with LinalgOp interface");
2417 if (linalgOp.getNumLoops() <= getDim())
2418 return emitOpError("expected dim (")
2419 << getDim() << ") to be lower than the number of loops ("
2420 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2421 return success();
2422}
2423
2424OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2425 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2426 // Bail out if `linalg.index` does not have a proper parent yet at this
2427 // point, e.g., when calling `createOrFold` during IR construction in
2428 // `GenericOp::build`.
2429 if (!linalgOp)
2430 return OpFoldResult{};
2431
2432 // Index of unit dims is always 0.
2433 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2434 uint64_t dim = getDim();
2435 assert(dim < loopBounds.size() && "Dim is out of bounds");
2436 if (loopBounds[dim] == 1)
2437 return IntegerAttr::get(IndexType::get(getContext()), 0);
2438
2439 return OpFoldResult{};
2440}
2441
2442/////// Operations corresponding to library calls defined with Tablegen ////////
2443
2444#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2445
2446#define GET_OP_CLASSES
2447#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2448
2449#define GET_OP_CLASSES
2450#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2451#define GET_OP_CLASSES
2452#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2453
2454AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2455 unsigned rank,
2456 MLIRContext *context) {
2457 if (maybeMap)
2458 return *maybeMap;
2459 if (rank == 0)
2460 return AffineMap::get(context);
2461 return AffineMap::getMultiDimIdentityMap(rank, context);
2462}
2463
2464SmallVector<AffineExpr, 4>
2465mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2466 MLIRContext *context) {
2467 SmallVector<AffineExpr, 4> res;
2468 res.reserve(num);
2469 for (unsigned i = 0; i < num; ++i)
2470 res.push_back(getAffineDimExpr(startIdx++, context));
2471 return res;
2472}
2473
2474SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2475 ArrayRef<AffineExpr> b) {
2476 auto rangeA = llvm::make_range(a.begin(), a.end());
2477 auto rangeB = llvm::make_range(b.begin(), b.end());
2478 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2479 return llvm::to_vector<4>(concatRanges);
2480}
2481
2482static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2483 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2484 ss << "view";
2485 for (auto size : memref.getShape())
2486 if (size < 0)
2487 ss << "sx";
2488 else
2489 ss << size << "x";
2490 if (failed(appendMangledType(ss, memref.getElementType())))
2491 return failure();
2492 if (auto as = memref.getMemorySpace()) {
2493 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2494 ss << "as" << attr.getInt();
2495 else
2496 return failure();
2497 }
2498 return success();
2499 }
2500 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2501 ss << "vector";
2502 llvm::interleave(
2503 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2504 if (failed(appendMangledType(ss, vec.getElementType())))
2505 return failure();
2506 return success();
2507 }
2508 if (t.isSignlessIntOrIndexOrFloat()) {
2509 ss << t;
2510 return success();
2511 }
2512 return failure();
2513}
2514
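/// Generates the mangled library call name for a linalg op. As a sketch of
/// the scheme implemented below, `linalg.matmul` on three memref<?x?xf32>
/// operands mangles to "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32";
/// dynamic sizes mangle as "s".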
2515std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2516 assert(isa<LinalgOp>(op));
2517 std::string name(op->getName().getStringRef().str());
2518 std::string fun = "";
2519 for (NamedAttribute kv : op->getAttrs()) {
2520 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2521 fun = stringifyEnum(ufa.getValue()).str() + "_";
2522 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2523 fun = stringifyEnum(bfa.getValue()).str() + "_";
2524 }
2525 }
2526 name.reserve(128);
2527 llvm::replace(name, '.', '_');
2528 llvm::raw_string_ostream ss(name);
2529 ss << "_" << fun;
2530 for (Type t : op->getOperandTypes()) {
2531 if (failed(appendMangledType(ss, t)))
2532 return std::string();
2533 ss << "_";
2534 }
2535 name.pop_back();
2536 return name;
2537}
2538
2539//===----------------------------------------------------------------------===//
2540// Canonicalizers and Folders.
2541//===----------------------------------------------------------------------===//
2542
2543namespace {
2544struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2545 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2546
2547 LogicalResult matchAndRewrite(LinalgOp op,
2548 PatternRewriter &rewriter) const override {
2549 for (OpOperand &opOperand : op->getOpOperands()) {
2550 // Linalg "inputs" may be either tensor or memref type.
2551 // tensor<0xelt_type> is a convention that may not always mean
2552 // "0 iterations". Only erase when we see memref<...x0x...>.
2553 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2554 if (!mt)
2555 continue;
2556 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2557 rewriter.eraseOp(op);
2558 return success();
2559 }
2560 }
2561 return failure();
2562 }
2563};
2564
2565/// Fold a LinalgOp with a `tensor.cast` consumer if the `tensor.cast` result
2566/// is more static than the linalg op's result.
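/// E.g. (a sketch), if a linalg op produces tensor<?x?xf32> and is consumed
/// by `tensor.cast ... to tensor<4x8xf32>`, the op is rebuilt with a static
/// `outs` type of tensor<4x8xf32> plus a cast back to the dynamic type, so
/// the static shape can keep propagating up through producers.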
2567struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2568 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2569
2570 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2571 PatternRewriter &rewriter) const override {
2572 if (!tensor::canFoldIntoProducerOp(castOp))
2573 return failure();
2574
2575 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2576 if (!linalgOp)
2577 return failure();
2578
2579 // The cast can be in a conditionally reachable region, in which case
2580 // folding would generate invalid code. Only conservatively fold ops in the
2581 // same block for now.
2582 if (castOp->getBlock() != linalgOp->getBlock())
2583 return failure();
2584
2585 OpBuilder::InsertionGuard guard(rewriter);
2586 rewriter.setInsertionPoint(linalgOp);
2587
2588 Location loc = linalgOp.getLoc();
2589 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2590 unsigned resultNumber = resultValue.getResultNumber();
2591 auto resultType =
2592 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2593 // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2594 // going from a more dynamic shape to a less dynamic shape. If the producer
2595 // for this cast, i.e. producer of the out operand, is also an operation
2596 // that folds with tensor.cast consumer (like this pattern), the cast will
2597 // continue to propagate as far up the stack as it can go.
2598 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2599 Value newOperand =
2600 tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
2601 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2602 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2603 linalgOp.getDpsInits().end());
2604 outputOperands[resultNumber] = newOperand;
2605 newOperands.append(outputOperands.begin(), outputOperands.end());
2606
2607 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2608 linalgOp->result_type_end());
2609 resultTypes[resultNumber] = resultType;
2610 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2611
2612 // Create a tensor.cast operation back to the original type.
2613 Value castBack = tensor::CastOp::create(
2614 rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));
2615
2616 SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2617 results[resultNumber] = castBack;
2618 rewriter.replaceOp(linalgOp, results);
2619 rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2620 return success();
2621 }
2622};
2623
2624/// For each operand in `operands`, this function maps the static sizes of its
2625/// dimensions to their affine dim expressions.
2626static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2627 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2628 for (OpOperand &opOperand : operands) {
2629 if (linalgOp.isScalar(&opOperand))
2630 continue;
2631 Value src = opOperand.get();
2632 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2633 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2634
2635 // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2636 // `tensor.cast` operation and source of the cast operation has a static
2637 // shape, then assign it to the `sourceShape`.
2638 auto *parentOp = src.getDefiningOp();
2639 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2640 if (parentOp) {
2641 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2642 Value castSource = castOp.getSource();
2643 auto castSourceType =
2644 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2645 if (castSourceType && castSourceType.hasStaticShape())
2646 sourceShape = castSourceType.getShape();
2647 }
2648 }
2649
2650 // If the source shape's dimension has a static shape, map the affine dim
2651 // expression to the known static size.
2652 for (unsigned i = 0; i < sourceShape.size(); i++) {
2653 if (sourceType.isDynamicDim(i))
2654 continue;
2655 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2656 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2657 }
2658 }
2659}
2660
2661/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
2662/// mapped in `affineExprToSize`. New operands are created in `newOperands`
2663/// and their result types are stored in `resultTypes`. If `opOperand` requires
2664/// no change, then `changeNeeded` stays false and the same operand is added to
2665/// the `newOperands` list.
2666static void createNewOperandWithStaticSizes(
2667 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2668 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2669 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2670 bool &changeNeeded) {
2671 Value src = opOperand->get();
2672 newOperands.push_back(src);
2673 if (linalgOp.isScalar(opOperand))
2674 return;
2675 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2676 Type resultType = sourceType;
2677 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2678 resultTypes.push_back(resultType);
2679 return;
2680 }
2681 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2682 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2683 SmallVector<int64_t> newShape;
2684 // If the operand is updated with a new shape, `newOperandNeeded` is set
2685 // to true.
2686 bool newOperandNeeded = false;
2687 for (unsigned i = 0; i < sourceShape.size(); i++) {
2688 int64_t dimShape = sourceShape[i];
2689 AffineExpr dimExpr = sourceMap.getResult(i);
2690 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2691 newShape.push_back(dimShape);
2692 continue;
2693 }
2694 // The dimension has a dynamic shape and the corresponding affine dim
2695 // expression is present in the map, so assign that size to the
2696 // dimension.
2697 newShape.push_back(affineExprToSize[dimExpr]);
2698 newOperandNeeded = true;
2699 }
2700 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2701 sourceType.getEncoding());
2702 if (newOperandNeeded) {
2703 changeNeeded = true;
2704 // Get the new operand value given its size and element type by
2705 // casting it.
2706 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2707 unsigned index = opOperand->getOperandNumber();
2708 newOperands[index] = newOperand;
2709 }
2710 if (linalgOp.isDpsInit(opOperand))
2711 resultTypes.push_back(resultType);
2712}
2713
2714/// Static shapes for the operands can be inferred if any one of the operands
2715/// has a static shape. This can be done by referring to the affine dim
2716/// expressions of the operands.
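/// E.g. (a sketch), with identity indexing maps (d0, d1) -> (d0, d1) on two
/// operands of types tensor<4x?xf32> and tensor<?x8xf32>, the bindings d0 = 4
/// and d1 = 8 let both operands be cast to tensor<4x8xf32>.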
2717struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2718 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2719
2720 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2721 PatternRewriter &rewriter) const override {
2722 if (!linalgOp.hasPureTensorSemantics())
2723 return failure();
2724
2725 // Maps must be projected permutations.
2726 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2727 return !map.isProjectedPermutation();
2728 }))
2729 return failure();
2730
2731 // Maps affine dim expressions to the static size of that dimension.
2732 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2733 Location loc = linalgOp.getLoc();
2734
2735 // For each affine dim expression, check whether the size is known. If it
2736 // is known, add it to the map.
2737 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2738
2739 SmallVector<Value> newOperands;
2740 SmallVector<Type> resultTypes;
2741
2742 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2743 // change in their types.
2744 bool changeNeeded = false;
2745 newOperands.reserve(linalgOp->getNumOperands());
2746 resultTypes.reserve(linalgOp.getNumDpsInits());
2747
2748 // Iterate over all the operands and update the static sizes.
2749 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2750 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2751 affineExprToSize, linalgOp, newOperands,
2752 resultTypes, changeNeeded);
2753 }
2754
2755 // If the generic op already has all the required static information, no
2756 // canonicalization is needed.
2757 if (!changeNeeded)
2758 return failure();
2759
2760 // Clone op.
2761 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2762 SmallVector<Value> replacements;
2763 replacements.reserve(newOp->getNumResults());
2764 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2765 Value newResult = std::get<1>(it);
2766 Value oldResult = std::get<0>(it);
2767 Type newType = newResult.getType();
2768 Type oldType = oldResult.getType();
2769 replacements.push_back(
2770 (newType != oldType)
2771 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2772 : newResult);
2773 }
2774 rewriter.replaceOp(linalgOp, replacements);
2775 return success();
2776 }
2777};
2778
2779} // namespace
2780
2781// All named ops canonicalizers and folders are auto-generated in the
2782// .cpp.inc.
2783
2784//===----------------------------------------------------------------------===//
2785// SoftmaxOp
2786//===----------------------------------------------------------------------===//
2787
2788LogicalResult SoftmaxOp::verify() {
2789 ShapedType inputType = getInputOperandType();
2790 ShapedType outputType = getOutputOperandType();
2791
2792 ArrayRef<int64_t> inputShape = inputType.getShape();
2793 ArrayRef<int64_t> outputShape = outputType.getShape();
2794 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2795 return emitOpError("incompatible output shape");
2796
2797 int64_t inputRank = getInputOperandRank();
2798 int64_t dimension = getDimension();
2799 if ((dimension < 0) || (dimension >= inputRank))
2800 return emitOpError("incorrect dimension specified");
2801
2802 return success();
2803}
2804
2805SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2806 int64_t operandRank = getInputOperandRank();
2807 SmallVector<Range> loopBounds(operandRank);
2808 Location loc = getLoc();
2809 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2810 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2811 Value source = getInput();
2812 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2813 loopBounds[dim].offset = zero;
2814 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2815 loopBounds[dim].stride = one;
2816 }
2817 return loopBounds;
2818}
2819
2820SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2821 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2822 utils::IteratorType::parallel);
2823 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2824 return iteratorTypes;
2825}
2826
2827FailureOr<TilingResult>
2828SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2829 ArrayRef<OpFoldResult> offsets,
2830 ArrayRef<OpFoldResult> sizes) {
2831 int64_t rank = getInputOperandRank();
2832 auto oneAttr = builder.getI64IntegerAttr(1);
2833 SmallVector<OpFoldResult> strides(rank, oneAttr);
2834 SmallVector<Value> tiledOperands;
2835 Operation *inputSlice =
2836 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2837 if (!inputSlice) {
2838 return emitOpError("failed to compute input slice");
2839 }
2840 tiledOperands.emplace_back(inputSlice->getResult(0));
2841 Operation *outputSlice =
2842 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2843 if (!outputSlice) {
2844 return emitOpError("failed to compute output slice");
2845 }
2846 tiledOperands.emplace_back(outputSlice->getResult(0));
2847
2848 SmallVector<Type, 4> resultTypes;
2849 if (hasPureTensorSemantics())
2850 resultTypes.push_back(tiledOperands[1].getType());
2851 Operation *tiledOp =
2852 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2853
2854 return TilingResult{
2855 {tiledOp},
2856 SmallVector<Value>(tiledOp->getResults()),
2857 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2858}
2859
2860LogicalResult SoftmaxOp::getResultTilePosition(
2861 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2862 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2863 SmallVector<OpFoldResult> &resultSizes) {
2864 if (resultNumber == 0) {
2865 resultOffsets.assign(offsets.begin(), offsets.end());
2866 resultSizes.assign(sizes.begin(), sizes.end());
2867 return success();
2868 }
2869 return failure();
2870}
2871
2872// cast(dynamic) -> static.
2873LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2874 return memref::foldMemRefCast(*this);
2875}
2876
2877LogicalResult
2878SoftmaxOp::reifyResultShapes(OpBuilder &b,
2879 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2880 SmallVector<OpFoldResult> shapes;
2881 Location loc = getOperation()->getLoc();
2882 IRRewriter rewriter(b);
2883 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2884 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2885 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2886 if (!outputShapedType.isDynamicDim(dim)) {
2887 // Static dim: Return IntegerAttr.
2888 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2889 } else {
2890 // Dynamic dim: Return Value.
2891 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2892 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2893 }
2894 }
2895 reifiedReturnShapes.emplace_back(std::move(shapes));
2896 return success();
2897}
2898
2899void SoftmaxOp::getEffects(
2900 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2901 &effects) {
2902 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2903 if (!llvm::isa<MemRefType>(operand.getType()))
2904 continue;
2905 effects.emplace_back(MemoryEffects::Read::get(),
2906 &getOperation()->getOpOperand(index), /*stage=*/0,
2907 /*effectOnFullRegion=*/true,
2908 SideEffects::DefaultResource::get());
2909 }
2910
2911 for (OpOperand &operand : getDpsInitsMutable()) {
2912 if (!llvm::isa<MemRefType>(operand.get().getType()))
2913 continue;
2914 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2915 /*effectOnFullRegion=*/true,
2916 SideEffects::DefaultResource::get());
2917 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2918 /*effectOnFullRegion=*/true,
2919 SideEffects::DefaultResource::get());
2920 }
2921}
2922
2923// Helper functions for softmax decomposition.
2924// @{
2925
2926// Helper function to produce the iterator types (reduction or parallel) and
2927// affine maps for the iterators used in the decomposition of softmax.
2928// This method creates:
2929// If allParallel == true:
2930// - iterator type: {parallel, ..., parallel}
2931// - affine maps:
2932// -- identity with inputRank dimensions.
2933// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2934// where N == inputRank.
2935//
2936// If allParallel == false:
2937// - iterator type at dim(i) == parallel for i != \p dim and
2938// dim(dim) == reduction.
2939// - affine map:
2940// -- identity with inputRank dimensions.
2941// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2942// where N == inputRank.
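// E.g., for inputRank == 3 and dim == 1 with allParallel == false: iterator
// types are {parallel, reduction, parallel}, the identity map is
// (d0, d1, d2) -> (d0, d1, d2), and the second map is
// (d0, d1, d2) -> (d0, d2).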
2943static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2944computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2945 int64_t dim, bool allParallel = false) {
2946 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2947 utils::IteratorType::parallel);
2948 if (!allParallel)
2949 iteratorTypes[dim] = utils::IteratorType::reduction;
2950 MLIRContext *ctxt = builder.getContext();
2951 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2952 SmallVector<AffineExpr, 2> affineExprs;
2953 for (int i = 0; i < inputRank; i++) {
2954 if (i != dim)
2955 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2956 }
2957 auto reductionMap =
2958 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2959 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2960 return std::make_tuple(iteratorTypes, indexingMaps);
2961}
2962
2963// Helper function to produce a linalg.generic that computes a reduction on
2964// dimension \p dim with the operation type \p T.
2965template <typename T>
2966static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2967 int64_t dim) {
2968 auto inputType = cast<ShapedType>(input.getType());
2969 ArrayRef<int64_t> inputShape = inputType.getShape();
2970 int64_t inputRank = inputShape.size();
2971 auto [iteratorTypes, indexingMaps] =
2972 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2973 assert(indexingMaps.size() == 2 &&
2974 "We should have two maps: 1 for the input, 1 for the output");
2975 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2976
2977 auto genericOp = linalg::GenericOp::create(
2978 builder, loc, output.getType(), input, output, indexingMaps,
2979 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2980 Value result = T::create(b, loc, args[0], args[1]);
2981 linalg::YieldOp::create(b, loc, result);
2982 });
2983 return genericOp.getResult(0);
2984}
2985
2986/// Produce a linalg generic that computes the second step of the softmax
2987/// decomposition: res = exp(input - max), where \p max is the max of \p input
2988/// on dimension \p dim.
2989static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2990 Value max, Value output, int64_t dim) {
2991 auto inputType = cast<ShapedType>(input.getType());
2992 ArrayRef<int64_t> inputShape = inputType.getShape();
2993 int64_t inputRank = inputShape.size();
2994 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2995 builder, inputRank, dim, /*allParallel=*/true);
2996 assert(indexingMaps.size() == 2 && "We should have one map for each input");
2997 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2998 // Add the affine map for the output argument.
2999 indexingMaps.push_back(indexingMaps[0]);
3000 auto genericOp = linalg::GenericOp::create(
3001 builder, loc, input.getType(), ValueRange{input, max}, output,
3002 indexingMaps, iteratorTypes,
3003 [&](OpBuilder &b, Location loc, ValueRange args) {
3004 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3005 Value result = math::ExpOp::create(b, loc, diff);
3006 linalg::YieldOp::create(b, loc, result);
3007 });
3008 return genericOp.getResult(0);
3009}
3010
3011/// Produce a linalg generic that computes the final step of the softmax
3012/// decomposition.
3013/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
3014/// yield n / d
3015/// }
3016static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
3017 Value denominator, Value output, int64_t dim) {
3018 auto inputType = cast<ShapedType>(numerator.getType());
3019 ArrayRef<int64_t> inputShape = inputType.getShape();
3020 int64_t inputRank = inputShape.size();
3021 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3022 builder, inputRank, dim, /*allParallel=*/true);
3023 assert(indexingMaps.size() == 2 &&
3024 "We should have one map for each input (2)");
3025 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
3026 // Add the affine map for the output tensor.
3027 indexingMaps.push_back(indexingMaps[0]);
3028 auto genericOp = linalg::GenericOp::create(
3029 builder, loc, numerator.getType(), ValueRange{numerator, denominator},
3030 output, indexingMaps, iteratorTypes,
3031 [&](OpBuilder &b, Location loc, ValueRange args) {
3032 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3033 linalg::YieldOp::create(b, loc, result);
3034 });
3035 return genericOp.getResult(0);
3036}
3037// @} End helper functions for softmax decomposition.
3038
3039/// Given an N-dimensional tensor x, this method converts
3040/// softmax(x) to the following sequence of operations:
3041///
3042/// 1. Compute the max of x along dimension d. This results
3043/// in an N-1 dimensional tensor m.
3044/// m = max(x, dim = d)
3045///
3046/// 2. Subtract a broadcasted m from x and exponentiate. This results in
3047/// an N-dimensional tensor z.
3048/// z = exp(x - m)
3049///
3050/// 3. Compute the sum of z along dimension d. This results in
3051/// an N-1 dimensional tensor l.
3052/// l = sum(z, dim = d)
3053///
3054/// 4. Divide z by l. This gives the N-dimensional softmax.
3055/// softmax = z / l
3056///
3057FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
3058 OpBuilder::InsertionGuard guard(b);
3059 b.setInsertionPoint(*this);
3060 Location loc = getLoc();
3061 Value input = getInput();
3062 ShapedType inputType = getInputOperandType();
3063 Type elementType = inputType.getElementType();
3064 int64_t reductionDim = getDimension();
3065 SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
3066 Value output = getOutput();
3067 dims.erase(dims.begin() + reductionDim);
3068 // Step 1: Compute max along dim.
3069 Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
3070 Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
3071 elementType, b, loc,
3072 /*useOnlyFiniteValue=*/true);
3073 Value neutralForMaxFInit =
3074 linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
3075 .result();
3076 Value max =
3077 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
3078
3079 // Step 2: Subtract max from input and exponentiate.
3080 Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
3081
3082 // Step 3: Compute sum along dim.
3083 Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
3084 b, loc, /*useOnlyFiniteValue=*/true);
3085 Value zeroInit =
3086 linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
3087 Value denominator =
3088 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
3089
3090 // Step 4: Compute softmax.
3091 Value result =
3092 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
3093 return SmallVector<Value>{result};
3094}
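// [Editorial example, not part of LinalgOps.cpp] A minimal, self-contained
// sketch of the four-step decomposition above, checked against the naive
// softmax definition. It uses only the C++ standard library (no MLIR) and
// illustrative data; compile it separately with any C++17 compiler.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  // One row of x; the reduction dimension d is the row itself.
  std::vector<double> x = {1.0, 2.0, 3.0};

  // Step 1: m = max(x, dim = d), seeded with the finite maxnumf identity
  // (mirroring useOnlyFiniteValue = true above).
  double m = std::numeric_limits<double>::lowest();
  for (double v : x)
    m = std::fmax(m, v);

  // Step 2: z = exp(x - m).
  std::vector<double> z;
  for (double v : x)
    z.push_back(std::exp(v - m));

  // Step 3: l = sum(z, dim = d), seeded with the addf identity (0).
  double l = 0.0;
  for (double v : z)
    l += v;

  // Step 4: softmax = z / l; agrees with exp(x) / sum(exp(x)).
  double naiveSum = 0.0;
  for (double v : x)
    naiveSum += std::exp(v);
  for (std::size_t i = 0; i < x.size(); ++i) {
    double stable = z[i] / l;
    double naive = std::exp(x[i]) / naiveSum;
    assert(std::fabs(stable - naive) < 1e-12);
    std::printf("softmax[%zu] = %f\n", i, stable);
  }
  return 0;
}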
3095
3096//===----------------------------------------------------------------------===//
3097// WinogradFilterTransformOp
3098//===----------------------------------------------------------------------===//
3099
3100LogicalResult WinogradFilterTransformOp::verify() {
3101 auto filterType = cast<ShapedType>(getFilter().getType());
3102 ArrayRef<int64_t> filterShape = filterType.getShape();
3103 int64_t filterH = filterShape[getFilterHDim()];
3104 int64_t filterW = filterShape[getFilterWDim()];
3105 WinogradConv2DFmr fmr = getFmr();
3106 int64_t m, r;
3107 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3108
3109 if (filterH != r && filterH != 1)
3110 return emitOpError("expected filter height to be either r or 1");
3111 if (filterW != r && filterW != 1)
3112 return emitOpError("expected filter width to be either r or 1");
3113 if (filterH == 1 && filterW == 1)
3114 return emitOpError("expected either filter height or width to equal r");
3115
3116 SmallVector<int64_t> expectedOutputShape;
3117 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3118 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3119 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3120 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3121
3122 auto outputType = cast<ShapedType>(getOutput().getType());
3123 ArrayRef<int64_t> outputShape = outputType.getShape();
3124 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3125 return emitOpError("unexpected output shape");
3126 }
3127 return success();
3128}
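// [Editorial example, not part of LinalgOps.cpp] A small sketch of the shape
// rule the verifier above enforces: for F(m, r), each transformed filter
// dimension has alpha = m + r - 1 points, and a dimension of size 1 stays
// untransformed. Values below are illustrative only.
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t m = 4, r = 3;             // F(4, 3)
  const int64_t filterH = 3, filterW = 3; // (F, KH, KW, C) with KH = KW = r
  const int64_t C = 64, F = 128;

  const int64_t alpha = m + r - 1; // 6 for F(4, 3)
  std::array<int64_t, 4> expected = {filterH == r ? alpha : 1, // alphaH
                                     filterW == r ? alpha : 1, // alphaW
                                     C, F};
  assert(expected[0] == 6 && expected[1] == 6);
  std::printf("expected output shape: %lld x %lld x %lld x %lld\n",
              (long long)expected[0], (long long)expected[1],
              (long long)expected[2], (long long)expected[3]);
  return 0;
}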
3129
3130SmallVector<Range>
3131WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3132 Location loc = getLoc();
3133 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3134 IntegerAttr oneAttr = builder.getIndexAttr(1);
3135 Value filter = getFilter();
3136 int64_t filterRank = getFilterOperandRank();
3137 SmallVector<Range> loopBounds(filterRank);
3138 for (unsigned dim = 0; dim < filterRank; ++dim) {
3139 loopBounds[dim].offset = zeroAttr;
3140 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3141 loopBounds[dim].stride = oneAttr;
3142 }
3143 return loopBounds;
3144}
3145
3146SmallVector<utils::IteratorType>
3147WinogradFilterTransformOp::getLoopIteratorTypes() {
3148 int64_t filterRank = getFilterOperandRank();
3149 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3150 utils::IteratorType::parallel);
3151 return iteratorTypes;
3152}
3153
3154LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3155 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3156 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3157 SmallVector<OpFoldResult> &resultSizes) {
3158 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3159 ShapedType filterType = getFilterOperandType();
3160 ArrayRef<int64_t> filterShape = filterType.getShape();
3161 int64_t filterH = filterShape[getFilterHDim()];
3162 int64_t filterW = filterShape[getFilterWDim()];
3163 WinogradConv2DFmr fmr = getFmr();
3164 int64_t m, r;
3165 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3166 int64_t alpha = m + r - 1;
3167 int64_t alphaH = filterH != 1 ? alpha : 1;
3168 int64_t alphaW = filterW != 1 ? alpha : 1;
3169 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3170 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3171
3172 resultOffsets.append(
3173 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3174 resultSizes.append(
3175 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3176
3177 return success();
3178}
3179
3180/// Implement tiling for winograd_filter_transform.
3181/// The input of winograd_filter_transform is (F, KH, KW, C).
3182/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3183/// Users can specify the tile sizes of F and C.
3184/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3185/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3186FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3187 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3188 ArrayRef<OpFoldResult> sizes) {
3189 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3190 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3191 ShapedType filterType = getFilterOperandType();
3192 ArrayRef<int64_t> filterShape = filterType.getShape();
3193 int64_t filterH = filterShape[getFilterHDim()];
3194 int64_t filterW = filterShape[getFilterWDim()];
3195 IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3196 IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3197 SmallVector<Value> tiledOperands;
3198 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3199
3200 sliceOffsets.append(
3201 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3202 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3203 sizes[getFilterCDim()]});
3204 int64_t filterRank = getFilterOperandRank();
3205 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3206 Location loc = getLoc();
3207 auto filterSlice = tensor::ExtractSliceOp::create(
3208 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3209 tiledOperands.emplace_back(filterSlice);
3210
3211 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3212 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3213 resultSizes)))
3214 return failure();
3215
3216 int64_t outputRank = getOutputOperandRank();
3217 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3218 auto outputSlice = tensor::ExtractSliceOp::create(
3219 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3220 tiledOperands.emplace_back(outputSlice);
3221
3222 SmallVector<Type> resultTypes;
3223 resultTypes.push_back(tiledOperands[1].getType());
3224 Operation *tiledOp =
3225 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3226
3227 return TilingResult{
3228 {tiledOp},
3229 SmallVector<Value>(tiledOp->getResults()),
3230 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3231}
3232
3233//===----------------------------------------------------------------------===//
3234// WinogradInputTransformOp
3235//===----------------------------------------------------------------------===//
3236
3237LogicalResult WinogradInputTransformOp::verify() {
3238 auto inputType = cast<ShapedType>(getInput().getType());
3239 ArrayRef<int64_t> inputShape = inputType.getShape();
3240 int64_t inputH = inputShape[getInputHDim()];
3241 int64_t inputW = inputShape[getInputWDim()];
3242 WinogradConv2DFmr fmr = getFmr();
3243 int64_t m, r;
3244 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3245 int64_t tileSize = m + r - 1;
3246
3247 auto outputType = cast<ShapedType>(getOutput().getType());
3248 ArrayRef<int64_t> outputShape = outputType.getShape();
3249 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3250 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3251
3252 SmallVector<int64_t> expectedOutputShape(6, inputH);
3253 if (ShapedType::isDynamic(inputH)) {
3254 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3255 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3256 } else {
3257 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3258 expectedOutputShape[getOutputTileHDim()] =
3259 leftTransform ? (inputH - (r - 1)) / m : inputH;
3260 }
3261 if (ShapedType::isDynamic(inputW)) {
3262 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3263 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3264 } else {
3265 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3266 expectedOutputShape[getOutputTileWDim()] =
3267 rightTransform ? (inputW - (r - 1)) / m : inputW;
3268 }
3269 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3270 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3271
3272 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3273 return emitOpError("unexpected output shape");
3274 }
3275 return success();
3276}
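// [Editorial example, not part of LinalgOps.cpp] Sketch of the static-shape
// arithmetic checked above: an H x W input decomposes into
// (H - (r - 1)) / m by (W - (r - 1)) / m tiles of alpha x alpha points each,
// with alpha = m + r - 1. Illustrative numbers only.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t m = 4, r = 3; // F(4, 3)
  const int64_t inputH = 14, inputW = 14;

  const int64_t alpha = m + r - 1;              // 6
  const int64_t tileH = (inputH - (r - 1)) / m; // (14 - 2) / 4 = 3
  const int64_t tileW = (inputW - (r - 1)) / m; // 3
  // Expected output shape: (alphaH, alphaW, tileH, tileW, N, C).
  std::printf("output: (%lld, %lld, %lld, %lld, N, C)\n", (long long)alpha,
              (long long)alpha, (long long)tileH, (long long)tileW);
  assert(tileH == 3 && tileW == 3);
  return 0;
}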
3277
3278SmallVector<Range>
3279WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3280 Location loc = getLoc();
3281 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3282 IntegerAttr oneAttr = builder.getIndexAttr(1);
3283 Value output = getOutput();
3284 int64_t outputRank = getOutputOperandRank();
3285 SmallVector<Range> loopBounds(outputRank);
3286 for (unsigned dim = 0; dim < outputRank; ++dim) {
3287 loopBounds[dim].offset = zeroAttr;
3288 // alphaH, alphaW, tileH, tileW, N, C
3289 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3290 loopBounds[dim].stride = oneAttr;
3291 }
3292 return loopBounds;
3293}
3294
3295SmallVector<utils::IteratorType>
3296WinogradInputTransformOp::getLoopIteratorTypes() {
3297 int64_t outputRank = getOutputOperandRank();
3298 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3299 utils::IteratorType::parallel);
3300 return iteratorTypes;
3301}
3302
3303LogicalResult WinogradInputTransformOp::getResultTilePosition(
3304 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3305 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3306 SmallVector<OpFoldResult> &resultSizes) {
3307 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3308 ShapedType outputType = getOutputOperandType();
3309 ArrayRef<int64_t> outputShape = outputType.getShape();
3310 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3311 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3312
3313 WinogradConv2DFmr fmr = getFmr();
3314 int64_t m, r;
3315 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3316 int64_t alpha = m + r - 1;
3317 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3318 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3319
3320 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3321 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3322
3323 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3324 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3325 offsets[getOutputCDim()]});
3326 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3327 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3328 sizes[getOutputCDim()]});
3329
3330 return success();
3331}
3332
3333/// Implement tiling for winograd_input_transform.
3334/// The input of winograd_input_transform is (N, H, W, C).
3335/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
3336/// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
3337/// `offsets` are the values for the offsets of tileH, tileW, N, C for one
3338/// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one tile.
3339FailureOr<TilingResult>
3340WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3341 ArrayRef<OpFoldResult> offsets,
3342 ArrayRef<OpFoldResult> sizes) {
3343 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3344 WinogradConv2DFmr fmr = getFmr();
3345 int64_t m, r;
3346 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3347
3348 ShapedType outputType = getOutputOperandType();
3349 ArrayRef<int64_t> outputShape = outputType.getShape();
3350 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3351 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3352
3353 Location loc = getLoc();
3354 MLIRContext *context = builder.getContext();
3355 auto identityAffineMap =
3356 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3357 auto offsetAffineMap =
3358 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3359 Value mappedOffsetH = affine::makeComposedAffineApply(
3360 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3361 offsets[getOutputTileHDim()]);
3362 Value mappedOffsetW = affine::makeComposedAffineApply(
3363 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3364 offsets[getOutputTileWDim()]);
3365 auto sizeAffineMap = AffineMap::get(
3366 1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3367 Value mappedSizeH = affine::makeComposedAffineApply(
3368 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3369 Value mappedSizeW = affine::makeComposedAffineApply(
3370 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3371
3372 SmallVector<Value> tiledOperands;
3373 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3374
3375 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3376 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3377 sliceOffsets.append(
3378 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3379 OpFoldResult sizeH =
3380 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3381 OpFoldResult sizeW =
3382 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3383 sliceSizes.append(
3384 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3385 int64_t inputRank = getInputOperandRank();
3386 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3387 auto inputSlice = tensor::ExtractSliceOp::create(
3388 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3389 tiledOperands.emplace_back(inputSlice);
3390
3391 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3392 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3393 resultSizes)))
3394 return failure();
3395
3396 int64_t outputRank = getOutputOperandRank();
3397 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3398 auto outputSlice = tensor::ExtractSliceOp::create(
3399 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3400 tiledOperands.emplace_back(outputSlice);
3401
3402 SmallVector<Type> resultTypes;
3403 resultTypes.push_back(tiledOperands[1].getType());
3404 Operation *tiledOp =
3405 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3406
3407 return TilingResult{
3408 {tiledOp},
3409 SmallVector<Value>(tiledOp->getResults()),
3410 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3411}
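// [Editorial example, not part of LinalgOps.cpp] Sketch of the two affine
// maps used above to translate a tile-space slice into an input-space slice:
// offsets scale by m (d0 * m) and sizes must cover the overlapping Winograd
// windows (d0 * m + (r - 1)). Illustrative values only.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t m = 4, r = 3; // F(4, 3)
  // A slice that starts at tile 2 and spans 3 tiles ...
  const int64_t tileOffset = 2, tileSize = 3;
  // ... reads input rows [8, 8 + 14): offset 2 * 4, size 3 * 4 + 2.
  const int64_t inputOffset = tileOffset * m;       // offsetAffineMap
  const int64_t inputSize = tileSize * m + (r - 1); // sizeAffineMap
  assert(inputOffset == 8 && inputSize == 14);
  return 0;
}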
3412
3413//===----------------------------------------------------------------------===//
3414// WinogradOutputTransformOp
3415//===----------------------------------------------------------------------===//
3416
3417LogicalResult WinogradOutputTransformOp::verify() {
3418 auto valueType = cast<ShapedType>(getValue().getType());
3419 ArrayRef<int64_t> valueShape = valueType.getShape();
3420 int64_t valueH = valueShape[getValueAlphaHDim()];
3421 int64_t valueW = valueShape[getValueAlphaWDim()];
3422 int64_t valueTileH = valueShape[getValueTileHDim()];
3423 int64_t valueTileW = valueShape[getValueTileWDim()];
3424 WinogradConv2DFmr fmr = getFmr();
3425 int64_t m, r;
3426 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3427 bool leftTransform = valueH != 1;
3428 bool rightTransform = valueW != 1;
3429
3430 int64_t outputRank = getOutputOperandRank();
3431 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3432 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3433 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3434 } else {
3435 if (valueH != (leftTransform ? m + r - 1 : 1))
3436 return emitOpError("expected input height to equal the input tile size");
3437 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3438 }
3439 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3440 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3441 } else {
3442 if (valueW != (rightTransform ? m + r - 1 : 1))
3443 return emitOpError("expected input width to equal the input tile size");
3444 expectedOutputShape[getOutputWDim()] =
3445 (rightTransform ? m : 1) * valueTileW;
3446 }
3447 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3448 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3449
3450 auto outputType = cast<ShapedType>(getOutput().getType());
3451 ArrayRef<int64_t> outputShape = outputType.getShape();
3452 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3453 return emitOpError("unexpected output shape");
3454 }
3455 return success();
3456}
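// [Editorial example, not part of LinalgOps.cpp] Sketch of the inverse shape
// rule checked above: each alpha x alpha value tile contributes an m x m
// output block, so a transformed dimension has m * tile output elements.
// Illustrative values only.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t m = 4, r = 3;       // F(4, 3)
  const int64_t valueH = m + r - 1; // 6; must equal the input tile size
  const int64_t valueTileH = 3;
  const int64_t outputH = m * valueTileH; // 12 output rows
  assert(valueH == 6 && outputH == 12);
  return 0;
}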
3457
3458SmallVector<Range>
3459WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3460 Location loc = getLoc();
3461 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3462 IntegerAttr oneAttr = builder.getIndexAttr(1);
3463 Value value = getValue();
3464 int64_t valueRank = getValueOperandRank();
3465 SmallVector<Range> loopBounds(valueRank);
3466 for (unsigned dim = 0; dim < valueRank; ++dim) {
3467 loopBounds[dim].offset = zeroAttr;
3468 // alphaH, alphaW, tileH, tileW, N, F
3469 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3470 loopBounds[dim].stride = oneAttr;
3471 }
3472 return loopBounds;
3473}
3474
3475SmallVector<utils::IteratorType>
3476WinogradOutputTransformOp::getLoopIteratorTypes() {
3477 int64_t valueRank = getValueOperandRank();
3478 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3479 utils::IteratorType::parallel);
3480 return iteratorTypes;
3481}
3482
3483LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3484 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3485 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3486 SmallVector<OpFoldResult> &resultSizes) {
3487 WinogradConv2DFmr fmr = getFmr();
3488 int64_t m, r;
3489 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3490
3491 Location loc = getLoc();
3492 MLIRContext *context = builder.getContext();
3493 auto identityAffineMap =
3494 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3495 auto affineMap =
3496 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3497
3498 ShapedType valueType = getValueOperandType();
3499 ArrayRef<int64_t> valueShape = valueType.getShape();
3500 int64_t valueH = valueShape[0];
3501 int64_t valueW = valueShape[1];
3502 Value mappedOffsetH = affine::makeComposedAffineApply(
3503 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3504 offsets[getValueTileHDim()]);
3505 Value mappedOffsetW = affine::makeComposedAffineApply(
3506 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3507 offsets[getValueTileWDim()]);
3508 Value mappedSizeH = affine::makeComposedAffineApply(
3509 builder, loc, affineMap, sizes[getValueTileHDim()]);
3510 Value mappedSizeW = affine::makeComposedAffineApply(
3511 builder, loc, affineMap, sizes[getValueTileWDim()]);
3512
3513 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3514 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3515 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3516 OpFoldResult sizeH =
3517 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3518 OpFoldResult sizeW =
3519 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3520
3521 resultOffsets.append(
3522 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3523 resultSizes.append(
3524 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3525 return success();
3526}
3527
3528/// Implement tiling for winograd_output_transform.
3529/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
3530/// N, F). The output of winograd_output_transform is (N, H, W, F).
3531/// Users can specify the tile sizes of tileH, tileW, N, and F.
3532/// `offsets` are the values for the offsets of tileH, tileW, N, F for one
3533/// tile. `sizes` are the values for the sizes of tileH, tileW, N, F for one tile.
3534FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3535 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3536 ArrayRef<OpFoldResult> sizes) {
3537 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3538 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3539 Location loc = getLoc();
3540 SmallVector<Value> tiledOperands;
3541 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3542
3543 ShapedType valueType = getValueOperandType();
3544 ArrayRef<int64_t> valueShape = valueType.getShape();
3545 int64_t alphaH = valueShape[getValueAlphaHDim()];
3546 int64_t alphaW = valueShape[getValueAlphaWDim()];
3547 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3548 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3549
3550 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3551 offsets[getValueTileWDim()], offsets[getValueNDim()],
3552 offsets[getValueFDim()]});
3553 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3554 sizes[getValueTileWDim()], sizes[getValueNDim()],
3555 sizes[getValueFDim()]});
3556 int64_t valueRank = getValueOperandRank();
3557 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3558 auto valueSlice = tensor::ExtractSliceOp::create(
3559 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3560 tiledOperands.emplace_back(valueSlice);
3561
3562 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3563 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3564 resultSizes)))
3565 return failure();
3566
3567 int64_t outputRank = getOutputOperandRank();
3568 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3569 auto outputSlice = tensor::ExtractSliceOp::create(
3570 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3571 tiledOperands.emplace_back(outputSlice);
3572
3573 SmallVector<Type> resultTypes;
3574 resultTypes.push_back(tiledOperands[1].getType());
3575 Operation *tiledOp =
3576 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3577
3578 return TilingResult{
3579 {tiledOp},
3580 SmallVector<Value>(tiledOp->getResults()),
3581 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3582}
3583
3584//===----------------------------------------------------------------------===//
3585// LinalgDialect
3586// TODO: Merge with the LinalgDialect block at the bottom
3587//===----------------------------------------------------------------------===//
3588
3589// Returns true if the result expressions of `subMap` are a subset of `fullMap`.
3590static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3591 auto explicitRange = subMap.getResults();
3592 auto defaultRange = fullMap.getResults();
3593 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3594 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3595 llvm::set_union(explicitSet, defaultSet);
3596 return explicitSet == defaultSet;
3597}
3598
3599/// Check if the user-defined map is a valid broadcast map. Here broadcast
3600/// indexing maps are defined in the context of the corresponding default
3601/// indexing maps for the given op. This makes the check very simple: just
3602/// compare the number of result dims.
3603/// Returns true if `explicitMap` is broadcast with respect to the
3604/// `defaultMap`.
3605static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
3606 return explicitMap.getNumResults() < defaultMap.getNumResults();
3607}
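// [Editorial example, not part of LinalgOps.cpp] The subset test above works
// by unioning the default results into the explicit results and checking
// that nothing new was added. A sketch of the same logic over plain ints
// (standing in for AffineExprs):
#include <cassert>
#include <set>

static bool resultsSubsetOf(const std::set<int> &sub,
                            const std::set<int> &full) {
  std::set<int> u = sub;
  u.insert(full.begin(), full.end()); // set_union into `u`
  return u == full;                   // true iff `sub` added nothing new
}

int main() {
  assert(resultsSubsetOf({2}, {0, 2}));  // broadcast map (d2) vs (d0, d2)
  assert(!resultsSubsetOf({1}, {0, 2})); // d1 is not a default result
  return 0;
}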
3608
3609/// Verifies the broadcast and transpose semantics specified by the explicit
3610/// indexing map for the MatmulOp \p op for each operand specified by \p
3611/// opIndex.
3612static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3613 unsigned opIndex) {
3614 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3615 SmallVector<AffineMap, 3> defaultIndexingMaps =
3616 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3617
3618 auto opIndexingMap = opIndexingMaps[opIndex];
3619 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3620 // Check general validity of indexing map results.
3621 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3622 return matmulOp->emitOpError()
3623 << "Unexpected dim expression in map result.";
3624
3625 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3626 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3627 return matmulOp->emitOpError()
3628 << "Invalid broadcast requested, should be (d2).";
3629 }
3630 return success();
3631 }
3632 return success();
3633}
3634
3635// Check general validity of input indexing map of
3636// BatchMatmulOp/BatchReduceMatmulOp.
3637template <typename OpTy>
3638static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3639 AffineMap opIndexingMap,
3640 AffineMap defaultIndexingMap, bool isLHS) {
3641 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3642 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3643 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3644 // Check the result dims are valid.
3645 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3646 return batchVariantMatmulOp->emitOpError()
3647 << "Unexpected result dim expression (outside the set of default "
3648 "result dims).";
3649
3650 // Check for valid number of result dims of input maps.
3651 if (opIndexingMap.getNumResults() > 3)
3652 return batchVariantMatmulOp->emitOpError()
3653 << "no. of result dim expressions exceeds 3.";
3654
3655 auto hasValidBatchDim = [](AffineMap map) {
3656 AffineExpr batchDim = map.getResult(0);
3657 return batchDim.isFunctionOfDim(0);
3658 };
3659
3660 // Check if the requested broadcast is valid.
3661 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3662 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3663 return batchVariantMatmulOp->emitOpError()
3664 << "Invalid broadcast requested.";
3665 } else if (!hasValidBatchDim(opIndexingMap)) {
3666 return batchVariantMatmulOp->emitOpError()
3667 << "Invalid batch dimension expression.";
3668 }
3669 return success();
3670}
3671
3672/// This function checks if the given AffineMap for the output of a
3673/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3674/// dimensions and if the output map result dimensions are valid.
3675template <typename OpTy>
3676static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3677 AffineMap opIndexingMap) {
3678 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3679 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3680 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3681 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3682 opIndexingMap.getNumResults() != 3) {
3683
3684 return batchVariantMatmulOp->emitOpError()
3685 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3686 << ").";
3687 }
3688 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3689 opIndexingMap.getNumResults() != 2) {
3690 return batchVariantMatmulOp->emitOpError()
3691 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3692 << ").";
3693 }
3694
3695 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3696 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3697 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3698 outputMap.getResult(1).isFunctionOfDim(1) &&
3699 outputMap.getResult(2).isFunctionOfDim(2)
3700 : outputMap.getResult(0).isFunctionOfDim(1) &&
3701 outputMap.getResult(1).isFunctionOfDim(2);
3702 };
3703
3704 if (!areValidOutputResultDim(opIndexingMap)) {
3705 return batchVariantMatmulOp->emitOpError()
3706 << "Invalid output map result dimension.";
3707 }
3708
3709 return success();
3710}
3711
3712/// Verifies the broadcast and transpose semantics specified by the explicit
3713/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3714/// specified by opIndex.
3715template <typename OpTy>
3716static LogicalResult
3717verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
3718 unsigned opIndex) {
3719 SmallVector<AffineMap, 3> opIndexingMaps =
3720 batchVariantMatmulOp.getIndexingMapsArray();
3721 SmallVector<AffineMap, 3> defaultIndexingMaps =
3722 batchVariantMatmulOp.getDefaultIndexingMaps(
3723 batchVariantMatmulOp->getContext());
3724
3725 if (opIndexingMaps.size() != 3)
3726 return batchVariantMatmulOp->emitOpError()
3727 << "Indexing_map attribute must have 3 affine maps.";
3728
3729 auto opIndexingMap = opIndexingMaps[opIndex];
3730 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3731
3732 if (opIndex == 2 &&
3733 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3734 return failure();
3735
3736 if (opIndex != 2 &&
3737 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3738 defaultIndexingMap, opIndex == 0)))
3739 return failure();
3740
3741 return success();
3742}
3743
3744namespace mlir {
3745namespace linalg {
3746
3747std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3748 if (m == 2 && r == 3)
3749 return WinogradConv2DFmr::F_2_3;
3750 if (m == 4 && r == 3)
3751 return WinogradConv2DFmr::F_4_3;
3752 if (m == 2 && r == 5)
3753 return WinogradConv2DFmr::F_2_5;
3754 return std::nullopt;
3755}
3756
3757std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3758 switch (fmr) {
3759 case WinogradConv2DFmr::F_2_3:
3760 return {2, 3};
3761 case WinogradConv2DFmr::F_4_3:
3762 return {4, 3};
3763 case WinogradConv2DFmr::F_2_5:
3764 return {2, 5};
3765 }
 llvm_unreachable("unhandled WinogradConv2DFmr enumerator");
3766}
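// [Editorial example, not part of LinalgOps.cpp] Usage sketch for the two
// helpers above: the enum and the (m, r) pair round-trip for the three
// supported Winograd configurations. The enum below is a plain C++ stand-in
// for WinogradConv2DFmr, not the MLIR type.
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

enum class Fmr { F_2_3, F_4_3, F_2_5 };

static std::optional<Fmr> fromPair(int64_t m, int64_t r) {
  if (m == 2 && r == 3)
    return Fmr::F_2_3;
  if (m == 4 && r == 3)
    return Fmr::F_4_3;
  if (m == 2 && r == 5)
    return Fmr::F_2_5;
  return std::nullopt;
}

static std::pair<int64_t, int64_t> toPair(Fmr f) {
  switch (f) {
  case Fmr::F_2_3:
    return {2, 3};
  case Fmr::F_4_3:
    return {4, 3};
  case Fmr::F_2_5:
    return {2, 5};
  }
  return {0, 0}; // unreachable for valid enumerators
}

int main() {
  for (Fmr f : {Fmr::F_2_3, Fmr::F_4_3, Fmr::F_2_5}) {
    auto [m, r] = toPair(f);
    assert(fromPair(m, r) == f); // round-trip
  }
  assert(!fromPair(3, 3).has_value()); // unsupported configuration
  return 0;
}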
3767
3768//===----------------------------------------------------------------------===//
3769// MatmulOp
3770//===----------------------------------------------------------------------===//
3771
3772static FailureOr<SmallVector<SmallVector<int64_t>>>
3773getAffineResultPositions(ArrayAttr maps) {
3774 SmallVector<SmallVector<int64_t>> positions;
3775 for (auto map : maps) {
3776 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3777 if (!attr)
3778 return failure();
3779 SmallVector<int64_t> pos;
3780 for (auto result : attr.getAffineMap().getResults()) {
3781 auto dim = dyn_cast<AffineDimExpr>(result);
3782 if (!dim)
3783 return failure();
3784 pos.push_back(dim.getPosition());
3785 }
3786 positions.push_back(pos);
3787 }
3788 return positions;
3789}
3790
3791/// Returns a list of AffineMaps with the typical matmul indexing characteristics.
3792SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3793 AffineExpr d0, d1, d2;
3794 SmallVector<AffineMap> indexingMaps;
3795 bindDims(context, d0, d1, d2);
3796 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3797 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3798 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3799 return indexingMaps;
3800}
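// [Editorial example, not part of LinalgOps.cpp] The three default maps
// above, (d0, d2), (d2, d1), and (d0, d1), are exactly the accesses of the
// textbook loop nest below (d0 = i, d1 = j, d2 = k, with d2 the reduction
// dim). A plain C++ sketch with illustrative 2x2 data:
#include <cassert>

int main() {
  const int M = 2, N = 2, K = 2;
  double A[2][2] = {{1, 2}, {3, 4}}; // accessed via (d0, d2)
  double B[2][2] = {{5, 6}, {7, 8}}; // accessed via (d2, d1)
  double C[2][2] = {{0, 0}, {0, 0}}; // accessed via (d0, d1)
  for (int i = 0; i < M; ++i)     // d0: parallel
    for (int j = 0; j < N; ++j)   // d1: parallel
      for (int k = 0; k < K; ++k) // d2: reduction
        C[i][j] += A[i][k] * B[k][j];
  assert(C[0][0] == 19 && C[0][1] == 22 && C[1][0] == 43 && C[1][1] == 50);
  return 0;
}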
3801
3802bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3803 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3804 if (!maps)
3805 return false;
3806 if (maps.size() != 3)
3807 return false;
3808 auto positions = getAffineResultPositions(maps);
3809 if (failed(positions))
3810 return false;
3811 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3812 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3813 (*positions)[2] == SmallVector<int64_t>{0, 1};
3814}
3815
3816SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3817 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3818 utils::IteratorType::parallel,
3819 utils::IteratorType::reduction};
3820}
3821
3822unsigned MatmulOp::getNumRegionArgs() { return 3; }
3823
3824std::string MatmulOp::getLibraryCallName() {
3825 return generateLibraryCallName(getOperation());
3826}
3827
3828bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3829
3830/// Check if the op has broadcast and/or transpose semantics. Returns true if
3831/// the user-defined indexing maps are not equal to the default maps.
3832bool MatmulOp::hasUserDefinedMaps() {
3833 SmallVector<AffineMap, 3> defaultMaps =
3834 getDefaultIndexingMaps(this->getContext());
3835 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3836 return defaultMaps != explicitMaps;
3837}
3838
3839/// Implements the block region builder for the MatmulOp. This is called by
3840/// 'fillStructuredOpRegion'.
3841void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3842 ArrayRef<NamedAttribute> attrs,
3843 function_ref<InFlightDiagnostic()> emitError) {
3844 if (emitError && block.getNumArguments() != 3) {
3845 emitError() << "MatmulOp regionBuilder expects 3 args, got "
3846 << block.getNumArguments();
3847 return;
3848 }
3849 assert(block.getNumArguments() == 3 &&
3850 "MatmulOp regionBuilder expects 3 args");
3851 RegionBuilderHelper helper(b, block);
3852 SmallVector<Value> yields;
3853
3854 TypeFn castVal = TypeFn::cast_signed;
3855 const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3856 return attr.getName() == "cast";
3857 });
3858 if (castIter != attrs.end()) {
3859 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3860 castVal = attr.getValue();
3861 }
3862
3863 Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3864 block.getArgument(0));
3865 Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3866 block.getArgument(1));
3867 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
3868 if (!value3)
3869 return;
3870 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3871 value3, emitError);
3872 if (!value4)
3873 return;
3874 yields.push_back(value4);
3875 helper.yieldOutputs(yields);
3876}
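// [Editorial example, not part of LinalgOps.cpp] Scalar sketch of what the
// region above computes per iteration: both inputs are cast to the output
// element type (cast_signed sign-extends integers), multiplied, and
// accumulated into the third block argument. The i8 -> i32 widening is an
// assumed scenario for illustration.
#include <cassert>
#include <cstdint>

static int32_t matmulRegion(int8_t a, int8_t b, int32_t acc) {
  int32_t lhs = static_cast<int32_t>(a); // cast_signed to the output type
  int32_t rhs = static_cast<int32_t>(b);
  return acc + lhs * rhs; // mul, then add, then yield
}

int main() {
  assert(matmulRegion(-3, 5, 100) == 85); // 100 + (-3 * 5)
  return 0;
}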
3877
3878/// Returns true if the given map `bcastMap` is a valid broadcast map. A valid
3879/// broadcast map must include the K dimension.
3880/// TODO: Strict inclusion of K dimension in the broadcast map is not
3881/// necessary for both input matrices simultaneously. We can relax this
3882/// condition to have K dimension for one input matrix map and infer the K
3883/// dimension for other input matrix map from the one already having K
3884/// dimension.
3885bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3886 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3887 AffineExpr expr = bcastMap.getResult(0);
3888 // Invalid map if the common dimension of matmul not found.
3889 return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3890}
3891
3892static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3893 if (parser.parseOptionalKeyword("indexing_maps"))
3894 return ArrayAttr{
3895 nullptr}; // Success in case indexing_maps was not provided.
3896
3897 ArrayAttr arrayAttr;
3898 if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3899 return failure();
3900
3901 if (llvm::any_of(arrayAttr,
3902 [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3903 return parser.emitError(parser.getCurrentLocation())
3904 << "element of indexing_maps array is not an affine_map";
3905
3906 return arrayAttr;
3907}
3908
3909ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3910 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3911 if (failed(indexingMapsAttr))
3912 return failure();
3913
3914 if (*indexingMapsAttr == nullptr) {
3915 auto indexingMapAttrs = llvm::map_to_vector(
3916 MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3917 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3918 indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3919 }
3920
3921 result.addAttribute("indexing_maps", *indexingMapsAttr);
3922 return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3923 MatmulOp::getRegionBuilder());
3924}
3925
3926void MatmulOp::print(OpAsmPrinter &p) {
3927 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3928 MatmulOp::getDefaultIndexingMaps(getContext()),
3929 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3930 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3931 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3932
3933 std::array<StringRef, 3> elidedAttrs = {
3934 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3935 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3936 elidedAttrs);
3937}
3938
3939/// Verify the user defined indexing maps.
3940LogicalResult MatmulOp::verify() {
3941 // Verification of pure matmul is handled by verifyStructuredOpInterface().
3942 if (!hasUserDefinedMaps())
3943 return success();
3944
3945 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3946 if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3947 return failure();
3948 }
3949 return success();
3950}
3951
3952LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3953 return memref::foldMemRefCast(*this);
3954}
3955
3956void MatmulOp::getEffects(
3957 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3958 &effects) {
3959 if (hasPureTensorSemantics())
3960 return;
3961 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3962}
3963
3964Speculation::Speculatability MatmulOp::getSpeculatability() {
3965 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3966}
3967
3968SmallVector<AffineMap>
3969MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3970 AffineExpr d0, d1, d2;
3971 MLIRContext *context = builder.getContext();
3972 bindDims(context, d0, d1, d2);
3973 AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
3974 AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
3975 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
3976 return {mapLHS, mapRHS, mapOut};
3977}
3978
3979bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
3980 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3981 if (!maps)
3982 return false;
3983 if (maps.size() != 3)
3984 return false;
3985 auto positions = getAffineResultPositions(maps);
3986 if (failed(positions))
3987 return false;
3988 return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
3989 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3990 (*positions)[2] == SmallVector<int64_t>{0, 1};
3991}
3992
3994void MatmulTransposeAOp::build(OpBuilder &builder, OperationState &result,
3995 ValueRange inputs, ValueRange outputs,
3996 ArrayRef<NamedAttribute> attributes) {
3997 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
3998 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
3999}
4000
4001MatmulTransposeAOp
4002MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4003 ValueRange inputs, ValueRange outputs,
4004 ArrayRef<NamedAttribute> attributes) {
4005 OperationState state(location, getOperationName());
4006 build(builder, state, inputs, outputs, attributes);
4007 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4008 assert(res && "builder didn't return the right type");
4009 return res;
4010}
4011
4013void MatmulTransposeAOp::build(OpBuilder &builder, OperationState &result,
4014 TypeRange resultTensorTypes,
4015 ValueRange inputs, ValueRange outputs,
4016 ArrayRef<NamedAttribute> attributes) {
4017 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4018 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4019}
4020
4021MatmulTransposeAOp
4022MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4023 TypeRange resultTensorTypes, ValueRange inputs,
4024 ValueRange outputs,
4025 ArrayRef<NamedAttribute> attributes) {
4026 OperationState state(location, getOperationName());
4027 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4028 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4029 assert(res && "builder didn't return the right type");
4030 return res;
4031}
4032
4034void MatmulTransposeAOp::build(OpBuilder &builder, OperationState &result,
4035 TypeRange resultTensorTypes,
4036 ValueRange inputs, ValueRange outputs,
4037 Attribute cast,
4038 ArrayRef<NamedAttribute> attributes) {
4039 result.addAttribute("cast", cast);
4040 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4041 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4042}
4043
4044MatmulTransposeAOp
4045MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4046 TypeRange resultTensorTypes, ValueRange inputs,
4047 ValueRange outputs, Attribute cast,
4048 ArrayRef<NamedAttribute> attributes) {
4049 OperationState state(location, getOperationName());
4050 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4051 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4052 assert(res && "builder didn't return the right type");
4053 return res;
4054}
4055
4056bool MatmulTransposeAOp::classof(Operation *op) {
4057 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4058 MatmulTransposeAOp::isDefaultIndexingMaps(
4059 op->getAttr("indexing_maps"));
4060}
4061
4062SmallVector<AffineMap>
4063MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4064 AffineExpr d0, d1, d2;
4065 MLIRContext *context = builder.getContext();
4066 bindDims(context, d0, d1, d2);
4067 AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
4068 AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
4069 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4070 return {mapLHS, mapRHS, mapOut};
4071}
4072
4073bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
4074 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4075 if (!maps)
4076 return false;
4077 if (maps.size() != 3)
4078 return false;
4079 auto positions = getAffineResultPositions(maps);
4080 if (failed(positions))
4081 return false;
4082 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
4083 (*positions)[1] == SmallVector<int64_t>{1, 2} &&
4084 (*positions)[2] == SmallVector<int64_t>{0, 1};
4085}
4086
4088void MatmulTransposeBOp::build(OpBuilder &builder, OperationState &result,
4089 ValueRange inputs, ValueRange outputs,
4090 ArrayRef<NamedAttribute> attributes) {
4091 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4092 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4093}
4094
4095MatmulTransposeBOp
4096MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4097 ValueRange inputs, ValueRange outputs,
4098 ArrayRef<NamedAttribute> attributes) {
4099 OperationState state(location, getOperationName());
4100 build(builder, state, inputs, outputs, attributes);
4101 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4102 assert(res && "builder didn't return the right type");
4103 return res;
4104}
4105
4107void MatmulTransposeBOp::build(OpBuilder &builder, OperationState &result,
4108 TypeRange resultTensorTypes,
4109 ValueRange inputs, ValueRange outputs,
4110 ArrayRef<NamedAttribute> attributes) {
4111 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4112 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4113}
4114
4115MatmulTransposeBOp
4116MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4117 TypeRange resultTensorTypes, ValueRange inputs,
4118 ValueRange outputs,
4119 ArrayRef<NamedAttribute> attributes) {
4120 OperationState state(location, getOperationName());
4121 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4122 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4123 assert(res && "builder didn't return the right type");
4124 return res;
4125}
4126
4128void MatmulTransposeBOp::build(OpBuilder &builder, OperationState &result,
4129 TypeRange resultTensorTypes,
4130 ValueRange inputs, ValueRange outputs,
4131 Attribute cast,
4132 ArrayRef<NamedAttribute> attributes) {
4133 result.addAttribute("cast", cast);
4134 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4135 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4136}
4137
4138MatmulTransposeBOp
4139MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4140 TypeRange resultTensorTypes, ValueRange inputs,
4141 ValueRange outputs, Attribute cast,
4142 ArrayRef<NamedAttribute> attributes) {
4143 OperationState state(location, getOperationName());
4144 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4145 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4146 assert(res && "builder didn't return the right type");
4147 return res;
4148}
4149
4150bool MatmulTransposeBOp::classof(Operation *op) {
4151 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4152 MatmulTransposeBOp::isDefaultIndexingMaps(
4153 op->getAttr("indexing_maps"));
4154}
4155
4156SmallVector<AffineMap>
4157BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4158 AffineExpr d0, d1, d2, d3;
4159 MLIRContext *context = builder.getContext();
4160 bindDims(context, d0, d1, d2, d3);
4161 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
4162 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
4163 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4164 return {mapLHS, mapRHS, mapOut};
4165}
4166
4167bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
4168 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4169 if (!maps)
4170 return false;
4171 if (maps.size() != 3)
4172 return false;
4173 auto positions = getAffineResultPositions(maps);
4174 if (failed(positions))
4175 return false;
4176 return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4177 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4178 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4179}
4180
4181void BatchMatmulTransposeAOp::build(
4182 OpBuilder &builder, OperationState &result, ValueRange inputs,
4183 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4184 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4185 BatchMatmulOp::getRegionBuilder(),
4186 getDefaultIndexingMaps(builder));
4187}
4188
4189BatchMatmulTransposeAOp
4190BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4191 ValueRange inputs, ValueRange outputs,
4192 ArrayRef<NamedAttribute> attributes) {
4193 OperationState state(location, getOperationName());
4194 build(builder, state, inputs, outputs, attributes);
4195 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4196 assert(res && "builder didn't return the right type");
4197 return res;
4198}
4199
4200void BatchMatmulTransposeAOp::build(
4201 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4202 ValueRange inputs, ValueRange outputs,
4203 ArrayRef<NamedAttribute> attributes) {
4204 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4205 BatchMatmulOp::getRegionBuilder(),
4206 getDefaultIndexingMaps(builder));
4207}
4208
4209BatchMatmulTransposeAOp
4210BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4211 TypeRange resultTensorTypes, ValueRange inputs,
4212 ValueRange outputs,
4213 ArrayRef<NamedAttribute> attributes) {
4214 OperationState state(location, getOperationName());
4215 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4216 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4217 assert(res && "builder didn't return the right type");
4218 return res;
4219}
4220
4221void BatchMatmulTransposeAOp::build(
4222 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4223 ValueRange inputs, ValueRange outputs, Attribute cast,
4224 ArrayRef<NamedAttribute> attributes) {
4225 result.addAttribute("cast", cast);
4226 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4227 BatchMatmulOp::getRegionBuilder(),
4228 getDefaultIndexingMaps(builder));
4229}
4230
4231BatchMatmulTransposeAOp
4232BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4233 TypeRange resultTensorTypes, ValueRange inputs,
4234 ValueRange outputs, Attribute cast,
4235 ArrayRef<NamedAttribute> attributes) {
4236 OperationState state(location, getOperationName());
4237 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4238 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4239 assert(res && "builder didn't return the right type");
4240 return res;
4241}
4242
4243bool BatchMatmulTransposeAOp::classof(Operation *op) {
4244 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4245 BatchMatmulTransposeAOp::isDefaultIndexingMaps(
4246 op->getAttr("indexing_maps"));
4247}
4248
4249SmallVector<AffineMap>
4250BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4251 AffineExpr d0, d1, d2, d3;
4252 MLIRContext *context = builder.getContext();
4253 bindDims(context, d0, d1, d2, d3);
4254 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4255 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4256 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4257 return {mapLHS, mapRHS, mapOut};
4258}
4259
4260bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
4261 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4262 if (!maps)
4263 return false;
4264 if (maps.size() != 3)
4265 return false;
4266 auto positions = getAffineResultPositions(maps);
4267 if (failed(positions))
4268 return false;
4269 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4270 (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4271 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4272}
4273
4274void BatchMatmulTransposeBOp::build(
4275 OpBuilder &builder, OperationState &result, ValueRange inputs,
4276 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4277 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4278 BatchMatmulOp::getRegionBuilder(),
4279 getDefaultIndexingMaps(builder));
4280}
4281
4282BatchMatmulTransposeBOp
4283BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4284 ValueRange inputs, ValueRange outputs,
4285 ArrayRef<NamedAttribute> attributes) {
4286 OperationState state(location, getOperationName());
4287 build(builder, state, inputs, outputs, attributes);
4288 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4289 assert(res && "builder didn't return the right type");
4290 return res;
4291}
4292
4293void BatchMatmulTransposeBOp::build(
4294 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4295 ValueRange inputs, ValueRange outputs,
4296 ArrayRef<NamedAttribute> attributes) {
4297 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4298 BatchMatmulOp::getRegionBuilder(),
4299 getDefaultIndexingMaps(builder));
4300}
4301
4302BatchMatmulTransposeBOp
4303BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4304 TypeRange resultTensorTypes, ValueRange inputs,
4305 ValueRange outputs,
4306 ArrayRef<NamedAttribute> attributes) {
4307 OperationState state(location, getOperationName());
4308 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4309 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4310 assert(res && "builder didn't return the right type");
4311 return res;
4312}
4313
4314void BatchMatmulTransposeBOp::build(
4315 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4316 ValueRange inputs, ValueRange outputs, Attribute cast,
4317 ArrayRef<NamedAttribute> attributes) {
4318 result.addAttribute("cast", cast);
4319 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4320 BatchMatmulOp::getRegionBuilder(),
4321 getDefaultIndexingMaps(builder));
4322}
4323
4324BatchMatmulTransposeBOp
4325BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4326 TypeRange resultTensorTypes, ValueRange inputs,
4327 ValueRange outputs, Attribute cast,
4328 ArrayRef<NamedAttribute> attributes) {
4329 OperationState state(location, getOperationName());
4330 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4331 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4332 assert(res && "builder didn't return the right type");
4333 return res;
4334}
4335
4336bool BatchMatmulTransposeBOp::classof(Operation *op) {
4337 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4338 BatchMatmulTransposeBOp::isDefaultIndexingMaps(
4339 op->getAttr("indexing_maps"));
4340}
4341
4342//===----------------------------------------------------------------------===//
4343// ContractOp
4344//===----------------------------------------------------------------------===//
4345
4346SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
4347 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4348 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
4349 // domains are all the same, and each implements a projected permutation.
4350 // Each iteration space dim must occur for at least one operand and either
4351 // takes part in a contraction/reduction or else has parallel iteration type.
4352 // We have that a dim is a contraction/reduction dim if and only if the dim
4353 // occurs for the output operand. We use this fact for fast inference:
4354 // NB: In case we allow dims to occur solely for one input, the above still
4355 // holds: per the einsum semantics, these are reduction dims as well.
4356 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
4357 for (auto result : outAffineMap.getResults()) {
4358 auto dimExpr = dyn_cast<AffineDimExpr>(result);
4359 assert(dimExpr && "affine_map is a projected permutation");
4360 dimsInOutput[dimExpr.getPosition()] = true;
4361 }
4362
4363 SmallVector<utils::IteratorType> iteratorTypes;
4364 for (auto dimOccursInOutput : dimsInOutput)
4365 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4366 : utils::IteratorType::reduction);
4367
4368 return iteratorTypes;
4369}
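// [Editorial example, not part of LinalgOps.cpp] Sketch of the inference
// above: a dim is parallel iff it appears in the output map, otherwise it is
// a reduction dim. Shown for matmul-like maps over (d0, d1, d2) where the
// output accesses (d0, d1):
#include <cassert>
#include <string>
#include <vector>

int main() {
  const int numDims = 3;
  std::vector<int> outResults = {0, 1}; // output map results: d0, d1
  std::vector<bool> dimsInOutput(numDims, false);
  for (int d : outResults)
    dimsInOutput[d] = true;
  std::vector<std::string> iteratorTypes;
  for (bool inOut : dimsInOutput)
    iteratorTypes.push_back(inOut ? "parallel" : "reduction");
  assert(iteratorTypes[0] == "parallel" && iteratorTypes[1] == "parallel" &&
         iteratorTypes[2] == "reduction"); // d2 is the contraction dim
  return 0;
}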
4370
4371unsigned ContractOp::getNumRegionArgs() { return 3; }
4372
4373/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
4374void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4375 ArrayRef<NamedAttribute> attrs,
4376 function_ref<InFlightDiagnostic()> emitError) {
4377 if (emitError && block.getNumArguments() != 3) {
4378 emitError() << "ContractOp regionBuilder expects 3 args, got "
4379 << block.getNumArguments();
4380 return;
4381 }
4382 assert(block.getNumArguments() == 3 &&
4383 "ContractOp regionBuilder expects 3 args");
4384 RegionBuilderHelper helper(b, block);
4385
4386 TypeFn castSignedness = TypeFn::cast_signed;
4387 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4388 return attr.getName() == "cast";
4389 });
4390 if (castIter != attrs.end()) {
4391 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4392 castSignedness = attr.getValue();
4393 }
4394
4395 // TODO: Support fields with operators besides mult & add.
4396 Type outType = block.getArgument(2).getType();
4397 Value lhsAtOutType =
4398 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
4399 Value rhsAtOutType =
4400 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
4401 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4402 rhsAtOutType, emitError);
4403 if (!productAtOutType)
4404 return;
4405 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4406 productAtOutType, emitError);
4407 if (!result)
4408 return;
4409 helper.yieldOutputs({result});
4410}
4411
4412ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
4413 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4414 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
4415 return parser.emitError(parser.getCurrentLocation(),
4416 "expected 'indexing_maps' attribute");
4417 result.addAttribute("indexing_maps", *indexingMapsAttr);
4418
4419 return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
4420 regionBuilder);
4421}
4422
4423void ContractOp::print(OpAsmPrinter &p) {
4424 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4425 printNamedStructuredOp(
4426 p, getOperation(), getInputs(), getOutputs(),
4427 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
4428}
4429
4430LogicalResult ContractOp::verify() {
4431 int iterationSpaceDims = -1;
4432 // Map iter space dims to #occurrences in inputs' and output's affine_maps:
4433 // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
4434 // access an input operand (so occurrence count can be at most 2) and
4435 // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
4436 SmallVector<size_t> inOccurrences;
4437 SmallVector<size_t> outOccurrences;
4438
4439 // A helper so that for each operand's affine_map and type we check that ...
4440 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4441 bool isInput) -> LogicalResult {
4442 // ... the affine_map is a projected permutation;
4443 if (!affineMap.isProjectedPermutation())
4444 return emitError("provided affine_map is not a projected permutation");
4445
4446 // ... the rank of the affine_map's results and corresponding type match;
4447 if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
4448 if (affineMap.getNumResults() != shapedType.getRank())
4449 return emitError("ranks of shaped operand and results of corresponding "
4450 "affine_map differ");
4451 } else if (affineMap.getNumResults() != 0) {
4452 return emitError("affine_map specifies shaped access while operand has "
4453 "non-shaped type");
4454 }
4455
4456 // ... the rank of the affine_map's domain is the same as those seen prior;
4457 if (iterationSpaceDims == -1) {
4458 iterationSpaceDims = affineMap.getNumDims();
4459 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4460 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4461 } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
4462 return emitError("iteration spaces of provided affine_maps differ");
4463 }
4464
4465 // ... update counts of dims used to access either an input or the output.
4466 for (AffineExpr affineExpr : affineMap.getResults()) {
4467 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4468 if (!affineDimExpr)
4469 llvm_unreachable("affine_map is a projected permutation");
4470
4471 if (isInput)
4472 inOccurrences[affineDimExpr.getPosition()] += 1;
4473 else
4474 outOccurrences[affineDimExpr.getPosition()] += 1;
4475 }
4476
4477 return success();
4478 };
4479
4480 for (auto &&[affineMap, operandType, isInput] :
4481 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4482 SmallVector<bool>{true, true, false})) {
4483 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4484 return failure(); // NB: checkAffineMapAndType will emit relevant error.
4485 }
4486
4487 bool hasContractingDim = false;
4488 for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4489 size_t inOccCount = inOccurrences[dimIndex];
4490 size_t outOccCount = outOccurrences[dimIndex];
4491
4492 // We have a contracting dim if and only if ...
4493 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4494
4495 if (inOccCount == 0 && outOccCount == 0)
4496 return emitError() << "iteration space dim at index " << dimIndex
4497 << " not used to access any operand";
4498
4499 // NB: We disallow a dim which occurs for only one input operand and not
4500 // for the output. In terms of einsum semantics such dims have a
4501 // sensible meaning - namely an additional reduction for each such dim.
4502 // By contrast, the ContractionOpInterface does not know about this
4503 // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
4504 // while vector.contract's verifier accepts dims of this kind, many of
4505 // its lowerings give up on encountering these dims.
4506 // TODO: Remove following once we have comprehensive support for input-only
4507 // reduction dims, at both the linalg- and vector-dialect levels.
4508 if (inOccCount == 1 && outOccCount != 1)
4509 return emitError()
4510 << "iteration space dim at index " << dimIndex
4511 << " is neither a contracting dim nor of parallel iteration type";
4512 }
4513
4514 if (!hasContractingDim)
4515 return emitError("'indexing_maps' do not specify a contracting dimension");
4516
4517 return success();
4518}
4519
4520LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4521 return memref::foldMemRefCast(*this);
4522}
4523
4524void ContractOp::getEffects(
4525 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4526 &effects) {
4527 if (hasPureTensorSemantics())
4528 return;
4529 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4530}
4531
4532Speculation::Speculatability ContractOp::getSpeculatability() {
4533 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4534}
4535
4536//===----------------------------------------------------------------------===//
4537// Implementation of BatchMatmulOp
4538//===----------------------------------------------------------------------===//
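/// The default indexing maps built below correspond to (with d0 = batch,
/// d1 = m, d2 = n, d3 = k):
///   LHS: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
///   RHS: affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
///   OUT: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>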
4539SmallVector<AffineMap>
4540BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4541 AffineExpr d0, d1, d2, d3;
4542 SmallVector<AffineMap> indexingMaps;
4543 bindDims(context, d0, d1, d2, d3);
4544 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4545 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4546 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4547 return indexingMaps;
4548}
4549
4550bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4551 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4552 if (!maps)
4553 return false;
4554 if (maps.size() != 3)
4555 return false;
4556 auto positions = getAffineResultPositions(maps);
4557 if (failed(positions))
4558 return false;
4559 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4560 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4561 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4562}
4563
4564SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4565 return SmallVector<utils::IteratorType>{
4566 utils::IteratorType::parallel, utils::IteratorType::parallel,
4567 utils::IteratorType::parallel, utils::IteratorType::reduction};
4568}
4569
4570unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4571
4572std::string BatchMatmulOp::getLibraryCallName() {
4573 return generateLibraryCallName(getOperation());
4574}
4575
4576/// Checks if the op has broadcast and/or transpose semantics. Returns true if
4577/// the user-defined indexing maps are not equal to the default maps.
4578bool BatchMatmulOp::hasUserDefinedMaps() {
4579 SmallVector<AffineMap, 3> defaultMaps =
4580 getDefaultIndexingMaps(this->getContext());
4581 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4582 return defaultMaps != explicitMaps;
4583}
4584
4585/// Returns true if the given bcastMap is a valid broadcast map. A valid
4586/// broadcast map must include the K dimension.
4587/// TODO: Strict inclusion of K dimension in the broadcast map is not
4588/// necessary for both input matrices simultaneously. We can relax this
4589/// condition to require the K dimension in only one input matrix map and
4590/// infer the K dimension for the other input matrix map from the one that
4591/// already has it.
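/// For example (illustrative), affine_map<(batch, m, n, k) -> (k)> is a
/// valid single-result broadcast map for either input, since it keeps the K
/// dimension.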
4592bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4593 assert(bcastMap.getNumResults() < 3 &&
4594 "Expected less than 3 result dim expr.");
4595 bool isValid = false;
4596 enum Indices { batchPos, mPos, nPos, kPos };
4597 if (bcastMap.getNumResults() == 1) {
4598 AffineExpr expr = bcastMap.getResult(0);
4599 isValid = expr.isFunctionOfDim(kPos);
4600 } else if (bcastMap.getNumResults() == 2) {
4601 AffineExpr expr0 = bcastMap.getResult(0);
4602 AffineExpr expr1 = bcastMap.getResult(1);
4603 isValid =
4604 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4605 expr0.isFunctionOfDim(mPos)) &&
4606 expr1.isFunctionOfDim(kPos))
4607 : ((expr0.isFunctionOfDim(batchPos) &&
4608 expr1.isFunctionOfDim(kPos)) ||
4609 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4610 }
4611 return isValid;
4612}
4613
4614void BatchMatmulOp::regionBuilder(
4615 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4616 function_ref<InFlightDiagnostic()> emitError) {
4617 if (emitError && block.getNumArguments() != 3) {
4618 emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4619 << block.getNumArguments();
4620 return;
4621 }
4622 assert(block.getNumArguments() == 3 &&
4623 "BatchMatmulOp regionBuilder expects 3 args");
4624 RegionBuilderHelper helper(b, block);
4625 SmallVector<Value> yields;
4626
4627 TypeFn castVal = TypeFn::cast_signed;
4628 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4629 return attr.getName() == "cast";
4630 });
4631 if (castIter != attrs.end()) {
4632 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4633 castVal = attr.getValue();
4634 }
4635
4636 auto toType = block.getArgument(2).getType();
4637 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4638 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4639 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4640 Value addVal =
4641 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
4642 yields.push_back(addVal);
4643 helper.yieldOutputs(yields);
4644}
4645
4646ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4647 SmallVector<Attribute, 3> indexingMapsAttr;
4648 Attribute mapAttr;
4649 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4650 if (parser.parseEqual())
4651 return failure();
4652
4653 if (parser.parseLSquare())
4654 return failure();
4655
4656 do {
4657 if (parser.parseAttribute(mapAttr))
4658 return failure();
4659 if (!isa<AffineMapAttr>(mapAttr)) {
4660 return parser.emitError(parser.getCurrentLocation(),
4661 "expected affine map attribute");
4662 }
4663 indexingMapsAttr.push_back(mapAttr);
4664
4665 if (parser.parseOptionalComma())
4666 break;
4667 } while (true);
4668
4669 if (parser.parseRSquare())
4670 return failure();
4671 }
4672 // Initialize indexingMaps, if not supplied explicitly.
4673 if (indexingMapsAttr.empty()) {
4674 indexingMapsAttr = llvm::map_to_vector(
4675 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4676 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4677 }
4678 result.addAttribute("indexing_maps",
4679 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4680
4681 return ::parseNamedStructuredOp(parser, result,
4682 BatchMatmulOp::getNumRegionArgs(),
4683 BatchMatmulOp::getRegionBuilder());
4684}
4685
4686void BatchMatmulOp::print(OpAsmPrinter &p) {
4687 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4688 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4689 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4690 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4691 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4692
4693 std::array<StringRef, 3> elidedAttrs = {
4694 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4695 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4696 elidedAttrs);
4697}
4698
4699/// Verify the user-defined indexing maps.
4700LogicalResult BatchMatmulOp::verify() {
4701 // Verification of pure batch_matmul is handled by
4702 // verifyStructuredOpInterface().
4703 if (!hasUserDefinedMaps())
4704 return success();
4705
4706 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4707 if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
4708 return failure();
4709 }
4710 return success();
4711}
4712
4713LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4714 SmallVectorImpl<OpFoldResult> &) {
4715 return memref::foldMemRefCast(*this);
4716}
4717
4718void BatchMatmulOp::getEffects(
4719 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4720 &effects) {
4721 if (hasPureTensorSemantics())
4722 return;
4723 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4724}
4725
4726Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4727 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4728}
4729
4730//===----------------------------------------------------------------------===//
4731// ElementwiseOp
4732//===----------------------------------------------------------------------===//
4733//
4734namespace {
4735struct ArityGroupAndKind {
4736 // The enum class {Unary, Binary, Ternary, ..}
4737 ElementwiseArityGroup arityGroup;
4738
4739 // The kind (e.g. `exp` or `add`) belonging to the arity group.
4740 union Kind {
4741 UnaryFn unaryFn;
4742 BinaryFn binaryFn;
4743 TernaryFn ternaryFn;
4744 } kind;
4745};
4746
4747unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4748 return static_cast<unsigned>(arityGroup);
4749}
4750} // namespace
4751
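/// Decodes a flat ElementwiseKind value into its arity group and the
/// group-relative function kind. The ElementwiseKind enum lays out all unary
/// entries first, then the binary ones, then the ternary ones, so e.g. a
/// binary kind is recovered as BinaryFn(val - lastUnary).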
4752static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4753 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4754 constexpr int lastBinary =
4755 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4756 constexpr int lastTernary =
4757 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4758
4759 int val = static_cast<int>(kind);
4760 ArityGroupAndKind result;
4761
4762 if (val < lastUnary) {
4763 result.arityGroup = ElementwiseArityGroup::Unary;
4764 result.kind.unaryFn = static_cast<UnaryFn>(val);
4765 return result;
4766 }
4767 if (val < lastBinary) {
4768 result.arityGroup = ElementwiseArityGroup::Binary;
4769 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4770 return result;
4771 }
4772 if (val >= lastTernary) {
4773 llvm_unreachable("unhandled ElementwiseFn");
4774 }
4775 result.arityGroup = ElementwiseArityGroup::Ternary;
4776 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4777 return result;
4778}
4779
4780SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4781 auto rank = getResultRank();
4782 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4783}
4784
4785SmallVector<AffineMap>
4786ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4787 MLIRContext *context) {
4788 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4789 return SmallVector<AffineMap>(numMaps, map);
4790}
4791
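/// An illustrative sketch of the custom assembly handled below; the kind
/// mnemonic is assumed to match the `#linalg.elemwise_kind<...>` form
/// mentioned in the parser comment:
/// ```mlir
/// %0 = linalg.elementwise kind=#linalg.elemwise_kind<add>
///        ins(%a, %b : tensor<8x16xf32>, tensor<8x16xf32>)
///        outs(%c : tensor<8x16xf32>) -> tensor<8x16xf32>
/// ```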
4792ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4793 // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4794 Attribute attr;
4795 mlir::linalg::ElementwiseKind elemwiseKindVal;
4796 if (parser.parseKeyword("kind") || parser.parseEqual())
4797 return failure();
4798
4799 if (succeeded(parser.parseAttribute(attr))) {
4800 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4801 if (!elemwiseKindAttr)
4802 return parser.emitError(parser.getCurrentLocation(),
4803 "expected ElementwiseKind attribute");
4804 elemwiseKindVal = elemwiseKindAttr.getValue();
4805 } else {
4806 return parser.emitError(parser.getCurrentLocation(),
4807 "expected operation 'kind' attribute");
4808 }
4809 result.addAttribute(
4810 "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4811
4812 // Parse optional `indexing_maps`
4813 SmallVector<Attribute, 3> indexingMapsAttr;
4814 Attribute mapAttr;
4815 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4816 if (parser.parseEqual())
4817 return failure();
4818 if (parser.parseLSquare())
4819 return failure();
4820 do {
4821 if (parser.parseAttribute(mapAttr))
4822 return failure();
4823 if (!isa<AffineMapAttr>(mapAttr))
4824 return parser.emitError(parser.getCurrentLocation(),
4825 "expected affine map attribute");
4826 indexingMapsAttr.push_back(mapAttr);
4827 if (parser.parseOptionalComma())
4828 break;
4829 } while (true);
4830 if (parser.parseRSquare())
4831 return failure();
4832 }
4833 // At this stage of parsing, the only way to infer the number of region
4834 // args is through the op kind, as the input/output tensors are not parsed yet.
4835 auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4836 int numRegionArgs =
4837 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4838 if (parseNamedStructuredOp(parser, result, numRegionArgs,
4839 ElementwiseOp::getRegionBuilder())) {
4840 return parser.emitError(parser.getCurrentLocation(),
4841 "unable to parse elemwise op");
4842 }
4843
4844 // Initialize indexingMaps, if not supplied explicitly.
4845 if (indexingMapsAttr.empty()) {
4846 // We need to infer the numDims of the indexing maps from the output
4847 // type, which has already been parsed by now.
4848 auto resultType = result.operands[result.operands.size() - 1].getType();
4849 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4850 if (!shapedType)
4851 return parser.emitError(parser.getCurrentLocation(),
4852 "return type needs to be shaped type");
4853 auto numDims = shapedType.getRank();
4854 indexingMapsAttr = llvm::map_to_vector(
4855 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4856 parser.getContext()),
4857 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4858 }
4859
4860 result.addAttribute("indexing_maps",
4861 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4862 return success();
4863}
4864
4865void ElementwiseOp::print(OpAsmPrinter &p) {
4866 p << " kind=";
4867 p.printAttribute(getKindAttr());
4868 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4869 "indexing_maps"};
4870 unsigned arity =
4871 getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4872 unsigned numDims = getResultRank();
4873
4874 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4875 ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4876 getContext()),
4877 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4878
4879 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4880 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4881
4882 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4883 elidedAttrs);
4884}
4885
4886/// Implements the block region builder for the ElementwiseOp. This is called by
4887/// 'fillStructuredOpRegion'.
4888void ElementwiseOp::regionBuilder(
4889 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4890 function_ref<InFlightDiagnostic()> emitError) {
4891 ElementwiseKind elemwiseKind;
4892 for (auto attr : attrs) {
4893 if (attr.getName() == b.getStringAttr("kind")) {
4894 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4895 assert(kindAttr && "op kind attribute incorrectly set");
4896 elemwiseKind = kindAttr.getValue();
4897 break;
4898 }
4899 }
4900
4901 ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4902 auto arityGroup = groupAndKind.arityGroup;
4903 auto kind = groupAndKind.kind;
4904 if (emitError && block.getNumArguments() !=
4905 getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
4906 emitError() << "Elementwise regionBuilder expects "
4907 << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
4908 << block.getNumArguments();
4909 return;
4910 }
4911 assert(block.getNumArguments() ==
4912 getArityGroupAsUInt(arityGroup) + 1 /*output*/
4913 && "Elementwise regionBuilder number of block args mismatch");
4914
4915 RegionBuilderHelper helper(b, block);
4916 SmallVector<Value> yields;
4917 Value result;
4918
4919 if (arityGroup == ElementwiseArityGroup::Unary) {
4920 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4921
4922 } else if (arityGroup == ElementwiseArityGroup::Binary) {
4923 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4924 block.getArgument(1));
4925
4926 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4927 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4928 block.getArgument(1), block.getArgument(2));
4929
4930 } else {
4931 assert(false && "found unhandled category in elemwise");
4932 }
4933
4934 yields.push_back(result);
4935 helper.yieldOutputs(yields);
4936}
4937
4938LogicalResult ElementwiseOp::fold(FoldAdaptor,
4939 SmallVectorImpl<OpFoldResult> &) {
4940 return memref::foldMemRefCast(*this);
4941}
4942
4943void ElementwiseOp::getEffects(
4944 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4945 &effects) {
4946 if (hasPureTensorSemantics())
4947 return;
4948 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4949}
4950
4951Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4952 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4953}
4954
4955//===----------------------------------------------------------------------===//
4956// PackOp/UnPackOp Common
4957//===----------------------------------------------------------------------===//
4958
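/// Returns the outer dims of the packed type listed in the unpacked dims'
/// original order, i.e. with any outer_dims_perm inverted (a helper shared
/// by pack and unpack).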
4959template <typename OpTy, typename>
4960SmallVector<int64_t>
4961getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
4962 RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
4963 ? packOrUnPack.getDestType()
4964 : packOrUnPack.getSourceType();
4965 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4966 ? packOrUnPack.getSourceType()
4967 : packOrUnPack.getDestType();
4968 SmallVector<int64_t> result(
4969 packedType.getShape().take_front(unpackedType.getRank()));
4970 if (!packOrUnPack.getOuterDimsPerm().empty()) {
4971 applyPermutationToVector(
4972 result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
4973 }
4974 return result;
4975}
4980
4981// Given the (potentially) updated packed type, `newPackedTy`, generates an
4982// updated mixed-tile-sizes attribute. A tile size is updated only
4983// when:
4984// * a dim from newPackedTy is static, and
4985// * the corresponding size from mixedTiles is still dynamic.
4986// Otherwise, the original tile size is preserved.
4987// Note - packed-type-dim and mixed-tile-size should always match!
4988static SmallVector<OpFoldResult>
4989getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4990 SmallVector<OpFoldResult> mixedTiles) {
4991 SmallVector<OpFoldResult> newMixedTileSizes;
4992 for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4993 .getShape()
4994 .take_back(mixedTiles.size()),
4995 mixedTiles)) {
4996 int64_t shape = std::get<0>(it);
4997 if (shape == ShapedType::kDynamic) {
4998 newMixedTileSizes.push_back(std::get<1>(it));
4999 continue;
5000 }
5001
5002 // If the current result dim is static, update the dynamic mixed-size
5003 // (provided the original value is dynamic).
5004 OpFoldResult tile = std::get<1>(it);
5005 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
5006 // Already a constant
5007 newMixedTileSizes.push_back(tile);
5008 } else {
5009 assert(getConstantIntValue(tile).value() == shape &&
5010 "tile size and dim size don't match!");
5011 newMixedTileSizes.push_back(
5012 (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
5013 }
5014 }
5015
5016 return newMixedTileSizes;
5017}
5018
5019template <typename OpTy>
5020static LogicalResult
5021reifyResultShapesImpl(OpTy op, OpBuilder &builder,
5022 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5023 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5024 "applies to only pack or unpack operations");
5025 int64_t destRank = op.getDestRank();
5026 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
5027 reifiedReturnShapes[0] =
5028 tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
5029 return success();
5030}
5031
5032template <typename OpTy>
5033static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
5034 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5035 "applies to only pack or unpack operations");
5036 DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
5037 ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
5038 SmallVector<OpFoldResult> tiles = op.getMixedTiles();
5039 assert(tiles.size() == dimsToTile.size() &&
5040 "tiles must match indices of dimension to block");
5041 // bind the dimension `i` with the tile factor.
5042 for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5043 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5044 return dimAndTileMapping;
5045}
5046
5047template <typename OpTy>
5048static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
5049 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5050 "applies to only pack or unpack operations");
5051 Builder builder(op);
5052 SmallVector<OpFoldResult> mixedInnerTiles;
5053 unsigned dynamicValIndex = 0;
5054 for (int64_t staticTile : op.getStaticInnerTiles()) {
5055 if (ShapedType::isStatic(staticTile))
5056 mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
5057 else
5058 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5059 }
5060 return mixedInnerTiles;
5061}
5062
5063template <typename OpTy>
5064static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
5065 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5066 "applies to only pack or unpack operations");
5067 SmallVector<Value> dynamicTiles;
5068 SmallVector<int64_t> staticTiles;
5069 dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
5070 return staticTiles;
5071}
5072
5073/// Returns true if `dimsPos` is invalid. It is invalid when:
5074/// a) It contains duplicates.
5075/// b) At least one dimension is out of bounds, i.e. not in [0, rank).
5076/// c) The number of elements in `dimsPos` is greater than `rank`.
5077static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
5078 size_t rank) {
5079 size_t dimsPosSize = dimsPos.size();
5080 if (dimsPosSize > rank)
5081 return true;
5082 DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
5083 if (dimsPosSize != uniqued.size())
5084 return true;
5085 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
5086 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
5087 });
5088}
5089
5090template <typename OpTy>
5091static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
5092 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5093 "applies to only pack or unpack operations");
5094 Operation *op = packOrUnPack.getOperation();
5095
5096 // Return true if we have a zero-value tile.
5097 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
5098 return llvm::any_of(tiles, isZeroInteger);
5099 };
5100
5101 // Verify tiles. Do not allow zero tiles.
5102 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
5103 if (hasZeros(mixedTiles))
5104 return op->emitError("invalid zero tile factor");
5105
5106 // Verify inner_dims_pos and outer_dims_perm.
5107 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
5108 ? packOrUnPack.getSourceType()
5109 : packOrUnPack.getDestType();
5110 size_t unpackedRank = unpackedType.getRank();
5111 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
5112 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
5113 if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
5114 return op->emitError("invalid inner_dims_pos vector");
5115 if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
5116 return op->emitError("invalid outer_dims_perm vector");
5117 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5118 return op->emitError("outer_dims_perm must be a permutation or empty");
5119
5120 // Tiling factors must be less than or equal to the input rank for pack (or
5121 // output rank for unpack), and must match the number of `inner_dims_pos`.
5122 if (mixedTiles.size() > unpackedRank) {
5123 return op->emitError("tiling factors must be less than or equal to the "
5124 "input rank for pack or output rank for unpack");
5125 }
5126 if (mixedTiles.size() != innerDimsPos.size()) {
5127 return op->emitError(
5128 "tiling factors must equal the number of dimensions to tile");
5129 }
5130
5131 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5132 ? packOrUnPack.getDestType()
5133 : packOrUnPack.getSourceType();
5134 size_t packedRank = packedType.getRank();
5135 // Require output rank to match input rank + number of blocking factors.
5136 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5137 if (expectedPackedRank != packedRank) {
5138 return op->emitError(
5139 "packed rank != (unpacked rank + num tiling factors), got ")
5140 << packedRank << " != " << expectedPackedRank;
5141 }
5142
5143 // Verify result shape is greater than the minimum expected
5144 // by the pack operation, and that the output shape
5145 // represents full tiles.
5146 RankedTensorType expectedPackedType = PackOp::inferPackedType(
5147 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
5148 if (!llvm::all_of(
5149 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
5150 mixedTiles),
5151 [](std::tuple<int64_t, OpFoldResult> it) {
5152 int64_t shape = std::get<0>(it);
5153 if (Attribute attr =
5154 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
5155 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5156 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5157 return shape == staticTileSize;
5158 }
5159 return ShapedType::isDynamic(shape);
5160 })) {
5161 return op->emitError("mismatch between the specified inner tile sizes and "
5162 "the shape of the tiled dimensions in the packed type");
5163 }
5164 if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
5165 packedType.getShape()))) {
5166 return op->emitError("expected ")
5167 << expectedPackedType << " for the packed domain value, got "
5168 << packedType;
5169 }
5170 return success();
5171}
5172
5173namespace {
5174/// Subset of PackOp/UnPackOp fields used to compute the result of applying
5175/// various permutations to the op.
5176// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
5177// these. These may or may not become true foldings / canonicalizations
5178// depending on how aggressive we want to be in automatically folding
5179// transposes.
5180struct PackOrUnPackTransposeResult {
5181 SmallVector<int64_t> innerDimsPos;
5182 SmallVector<OpFoldResult> innerTiles;
5183 SmallVector<int64_t> outerDimsPerm;
5184};
5185} // namespace
5186
5187template <typename OpTy>
5188static PackOrUnPackTransposeResult
5190 ArrayRef<int64_t> innerPermutation,
5191 ArrayRef<int64_t> outerPermutation) {
5192 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5193 "applies to only pack or unpack operations");
5194 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5195 "some permutation must be non-empty");
5196 PackOrUnPackTransposeResult metadata;
5197 metadata.innerDimsPos =
5198 SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
5199 metadata.innerTiles =
5200 SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
5201 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5202 ? packOrUnPackOp.getSourceRank()
5203 : packOrUnPackOp.getDestRank();
5204 metadata.outerDimsPerm =
5205 packOrUnPackOp.getOuterDimsPerm().empty()
5206 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5207 : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
5208 if (!innerPermutation.empty()) {
5209 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5210 isPermutationVector(innerPermutation) &&
5211 "invalid inner permutation");
5212 applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
5213 applyPermutationToVector(metadata.innerTiles, innerPermutation);
5214 }
5215 if (!outerPermutation.empty()) {
5216 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5217 isPermutationVector(outerPermutation) &&
5218 "invalid outer permutation");
5219 applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
5220 }
5221 return metadata;
5222}
5223
5224//===----------------------------------------------------------------------===//
5225// PackOp
5226//===----------------------------------------------------------------------===//
5227
5228void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
5229 setNameFn(getResult(), "pack");
5230}
5231
5232void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5233 Value dest, ArrayRef<int64_t> innerDimsPos,
5234 ArrayRef<OpFoldResult> innerTiles,
5235 std::optional<Value> paddingValue,
5236 ArrayRef<int64_t> outerDimsPerm) {
5237 assert(innerDimsPos.size() == innerTiles.size() &&
5238 "number of tile sizes specified must match the specified number of "
5239 "original dimensions to be tiled");
5240 SmallVector<int64_t> staticTileSizes;
5241 SmallVector<Value> dynamicTileSizes;
5242 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5243 build(builder, state, dest.getType(), source, dest,
5244 paddingValue ? *paddingValue : nullptr,
5245 outerDimsPerm.empty() ? nullptr
5246 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5247 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5248 builder.getDenseI64ArrayAttr(staticTileSizes));
5249}
5250
5251LogicalResult
5252PackOp::reifyResultShapes(OpBuilder &builder,
5253 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5254 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5255}
5256
5257DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
5258 return getDimAndTileMappingImpl(*this);
5259}
5260
5261SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5262 return getMixedTilesImpl(*this);
5263}
5264
5265SmallVector<int64_t> PackOp::getStaticTiles() {
5266 return getStaticTilesImpl(*this);
5267}
5268
5269ArrayRef<int64_t> PackOp::getAllOuterDims() {
5270 ShapedType inputType = getSourceType();
5271 int64_t inputRank = inputType.getRank();
5272 return getDestType().getShape().take_front(inputRank);
5273}
5274
5275SmallVector<int64_t> PackOp::getTiledOuterDims() {
5276 auto innerDimsPos = getInnerDimsPos();
5277 SmallVector<int64_t> outerDims(getAllOuterDims());
5278 SmallVector<int64_t> res;
5279
5280 // Recover the original order of the outer dims.
5281 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5282 invertPermutationVector(outerDimPermInv);
5283 if (!outerDimPermInv.empty())
5284 applyPermutationToVector(outerDims, outerDimPermInv);
5285
5286 // Collect the outer dims corresponding to the tiled inner dims.
5287 for (auto index : innerDimsPos)
5288 res.push_back(outerDims[index]);
5289
5290 return res;
5291}
5292
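/// Returns true if the pack op requires an explicit padding value, i.e. when
/// a statically known tiled dimension is not divided evenly by its tile
/// size. For example (illustrative), tiling a static dim of size 10 with a
/// tile size of 4 leaves a partial tile, so a padding value is required.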
5293bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5294 ArrayRef<int64_t> innerDimsPos,
5295 ArrayRef<int64_t> outputShape,
5296 ArrayRef<int64_t> outerDimsPerm,
5297 ArrayRef<OpFoldResult> innerTiles) {
5298 SmallVector<int64_t> outputTileSizes(
5299 outputShape.take_front(inputShape.size()));
5300 if (!outerDimsPerm.empty()) {
5301 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5302 "expected output and outer_dims_perm to have same size");
5303 applyPermutationToVector(outputTileSizes,
5304 invertPermutationVector(outerDimsPerm));
5305 }
5306 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5307 if (ShapedType::isDynamic(inputShape[pos]))
5308 continue;
5309 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5310
5311 if (!constantTile) {
5312 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5313 (inputShape[pos] % outputTileSizes[pos] != 0))
5314 return true;
5315 } else if (inputShape[pos] % (*constantTile) != 0) {
5316 return true;
5317 }
5318 }
5319 return false;
5320}
5321
5322bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5323 ArrayRef<int64_t> innerDimsPos,
5324 ArrayRef<int64_t> outputShape,
5325 ArrayRef<int64_t> outerDimsPerm,
5326 ArrayRef<OpFoldResult> innerTiles) {
5327 SmallVector<int64_t> outputTileSizes(
5328 outputShape.take_front(inputShape.size()));
5329 if (!outerDimsPerm.empty()) {
5330 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5331 "expected output and outer_dims_perm to have same size");
5332 applyPermutationToVector(outputTileSizes,
5333 invertPermutationVector(outerDimsPerm));
5334 }
5335 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5336 if (ShapedType::isDynamic(inputShape[pos]) ||
5337 ShapedType::isDynamic(outputTileSizes[pos]))
5338 return true;
5339 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5340 if (!constantTile)
5341 return true;
5342 if (inputShape[pos] % (*constantTile) != 0)
5343 return true;
5344 }
5345 return false;
5346}
5347
5348LogicalResult PackOp::verify() {
5349 if (failed(commonVerifierPackAndUnPackOp(*this)))
5350 return failure();
5351
5352 // Verify padding value, and bail out if the tile does not divide the
5353 // dimension fully. In the case of dynamic tile factors or dimensions, having
5354 // a partial tile is undefined behavior.
5355 auto paddingValue = getPaddingValue();
5356 if (paddingValue &&
5357 paddingValue.getType() != getSourceType().getElementType()) {
5358 return emitOpError("expected padding_value has ")
5359 << getSourceType().getElementType()
5360 << " but got: " << paddingValue.getType();
5361 }
5362
5363 if (!paddingValue &&
5364 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
5365 getDestType().getShape(), getOuterDimsPerm(),
5366 getMixedTiles())) {
5367 return emitOpError(
5368 "invalid tile factor or output size provided. Only full tiles are "
5369 "supported when padding_value is not set");
5370 }
5371 return success();
5372}
5373
5374/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
5375/// Values to kDynamic, even if they are arith.constant values.
5376static SmallVector<int64_t>
5377asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
5378 SmallVector<int64_t> result;
5379 for (auto o : ofrs) {
5380 // Have to do this first, as getConstantIntValue special-cases constants.
5381 if (llvm::dyn_cast_if_present<Value>(o))
5382 result.push_back(ShapedType::kDynamic);
5383 else
5384 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
5385 }
5386 return result;
5387}
5388
5389/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
5390/// the packed type. Having a shared helper helps implement these two methods in
5391/// a way that ensures that they agree on which dimensions are dynamic.
5392static SmallVector<int64_t> getPackOpResultTypeShape(
5393 ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
5394 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5395 SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
5396 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5397 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5398 continue;
5399 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5400 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5401 continue;
5402 }
5403 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5404 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5405 }
5406
5407 // Swap tile loops if outer_dims_perm is available.
5408 if (!outerDimsPerm.empty())
5409 applyPermutationToVector(resultShape, outerDimsPerm);
5410
5411 // Append the inner tile dimensions.
5412 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5413 return resultShape;
5414}
5415
5416SmallVector<OpFoldResult> PackOp::getResultShape(
5417 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5418 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5419 ArrayRef<int64_t> outerDimsPerm) {
5420 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5421
5422 AffineExpr s0, s1;
5423 bindSymbols(builder.getContext(), s0, s1);
5424 AffineExpr ceilDivExpr = s0.ceilDiv(s1);
5425 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5426 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
5427 builder, loc, ceilDivExpr,
5428 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5429 }
5430 if (!outerDimsPerm.empty())
5431 applyPermutationToVector(resultDims, outerDimsPerm);
5432 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5433
5434 SmallVector<int64_t> resultTypeShape =
5435 getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
5436 asShapeWithAnyValueAsDynamic(innerTileSizes),
5437 innerDimsPos, outerDimsPerm);
5438
5439 // Fix-up `resultDims` to ensure that they are Values if and only if the
5440 // result type shape says it's a dynamic dim. This is needed as callers may
5441 // use dispatchIndexOpFoldResults on the result, and rely on exact number of
5442 // dynamic dims returned by that.
5443 for (unsigned i = 0; i < resultDims.size(); ++i) {
5444 if (ShapedType::isStatic(resultTypeShape[i]))
5445 continue;
5446 resultDims[i] =
5447 getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
5448 }
5449
5450 return resultDims;
5451}
5452
5453/// Get the expected packed type based on source type, tile factors, position of
5454/// the inner tiles and permutation of the outer tiled loop.
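/// For example (illustrative), packing a tensor<256x512xf32> with inner
/// tiles [32, 16] at inner_dims_pos [0, 1] and an empty outer_dims_perm
/// yields tensor<8x32x32x16xf32>: each tiled dim is ceil-divided by its tile
/// size and the tile sizes are appended as the innermost dims.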
5455RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
5456 ArrayRef<int64_t> innerTileSizes,
5457 ArrayRef<int64_t> innerDimsPos,
5458 ArrayRef<int64_t> outerDimsPerm) {
5459 SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
5460 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5461 return RankedTensorType::get(resultShape, sourceType.getElementType());
5462}
5463
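/// Creates the destination tensor.empty for a pack: tiled dims are
/// ceil-divided by their tile sizes, the outer dims are permuted by
/// outer_dims_perm if present, and the tile sizes are appended (an
/// OpFoldResult-based analogue of inferPackedType).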
5464Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
5465 ArrayRef<OpFoldResult> innerTileSizes,
5466 ArrayRef<int64_t> innerDimsPos,
5467 ArrayRef<int64_t> outerDimsPerm) {
5468 AffineExpr dim0, dim1;
5469 bindDims(b.getContext(), dim0, dim1);
5470 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5471 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
5472 {v1, v2});
5473 };
5474
5475 SmallVector<OpFoldResult> mixedSizes;
5476 for (auto [index, value] : llvm::enumerate(
5477 llvm::cast<RankedTensorType>(source.getType()).getShape())) {
5478 if (ShapedType::isDynamic(value))
5479 mixedSizes.push_back(
5480 tensor::DimOp::create(b, loc, source, index).getResult());
5481 else
5482 mixedSizes.push_back(b.getIndexAttr(value));
5483 }
5484 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5485 int64_t dimPos = std::get<0>(it);
5486 OpFoldResult tileSize = std::get<1>(it);
5487 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5488 }
5489 if (!outerDimsPerm.empty())
5490 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
5491
5492 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5493 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5494 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5495}
5496
5497PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
5498 ArrayRef<int64_t> innerPermutation,
5499 ArrayRef<int64_t> outerPermutation) {
5500 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5501 *this, innerPermutation, outerPermutation);
5502 Value transposedDest =
5503 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5504 metadata.innerDimsPos, metadata.outerDimsPerm);
5505 return PackOp::create(b, loc, getSource(), transposedDest,
5506 metadata.innerDimsPos, metadata.innerTiles,
5507 getPaddingValue(), metadata.outerDimsPerm);
5508}
5509
5510/// Returns true if the tiles and the tiled dims are constant.
5511template <typename OpTy>
5512static bool areTilesAndTiledDimsAllConstant(OpTy op) {
5513 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5514 "applies to only pack or unpack operations");
5515 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5516 ? op.getDestType()
5517 : op.getSourceType();
5518 SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
5519 for (auto [dimDest, tile] : llvm::zip(
5520 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5521 std::optional<int64_t> constTileSize = getConstantIntValue(tile);
5522 if (!constTileSize || ShapedType::isDynamic(dimDest))
5523 return false;
5524 }
5525 return true;
5526}
5527
5528Speculation::Speculatability PackOp::getSpeculatability() {
5529 if (getPaddingValue())
5530 return Speculation::Speculatable;
5531
5532 // The verifier already rejects operations if we can statically prove that
5533 // the tile sizes do not divide the dimension perfectly; thus, only check
5534 // that the tiles and the tiled inner dimensions are constant.
5535 if (!areTilesAndTiledDimsAllConstant(*this))
5536 return Speculation::NotSpeculatable;
5537
5538 return Speculation::Speculatable;
5539}
5540
5541// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5542// dimensions for pack and unpack.
5543static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5544 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5545 return false;
5546 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5547 return true;
5548 // Outer dims permutation is optional.
5549 // To compare unbalanced pack-unpack pair, treat no permutation as equal to
5550 // identity permutation.
5551 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5552 isIdentityPermutation(unPackOp.getOuterDimsPerm());
5553}
5554
5555// Return true if pack and unpack have the same tiles.
5556// Same SSA values or same integer constants.
5557static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5558 auto packTiles = packOp.getMixedTiles();
5559 auto unPackTiles = unPackOp.getMixedTiles();
5560 if (packTiles.size() != unPackTiles.size())
5561 return false;
5562 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5563 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5564 return false;
5565 }
5566 return true;
5567}
5568
5569/// Returns true if the pack op does not need a padding value.
5570static bool paddingIsNotNeeded(PackOp op) {
5571 auto srcType = op.getSourceType();
5572 if (llvm::any_of(op.getInnerDimsPos(),
5573 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5574 return false;
5575 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5576 return false;
5577 return !PackOp::requirePaddingValue(
5578 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5579 op.getOuterDimsPerm(), op.getMixedTiles());
5580}
5581
5582/// Returns true if the `srcShape` or `destShape` is different from the one in
5583/// `packOp` and populates each with the inferred static shape.
5584static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5585 SmallVectorImpl<int64_t> &destShape) {
5586 bool changeNeeded = false;
5587 srcShape.assign(packOp.getSourceType().getShape().begin(),
5588 packOp.getSourceType().getShape().end());
5589 destShape.assign(packOp.getDestType().getShape().begin(),
5590 packOp.getDestType().getShape().end());
5591 llvm::SmallSetVector<int64_t, 4> innerDims;
5592 innerDims.insert_range(packOp.getInnerDimsPos());
5593 SmallVector<int64_t> inverseOuterDimsPerm;
5594 if (!packOp.getOuterDimsPerm().empty())
5595 inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5596 int srcRank = packOp.getSourceRank();
5597 for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5598 if (innerDims.contains(i))
5599 continue;
5600 int64_t srcPos = i;
5601 int64_t destPos = i;
5602 if (!inverseOuterDimsPerm.empty())
5603 destPos = inverseOuterDimsPerm[srcPos];
5604 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5605 ShapedType::isDynamic(destShape[destPos])) {
5606 continue;
5607 }
5608 int64_t size = srcShape[srcPos];
5609 if (ShapedType::isDynamic(size))
5610 size = destShape[destPos];
5611 srcShape[srcPos] = size;
5612 destShape[destPos] = size;
5613 changeNeeded = true;
5614 }
5615 return changeNeeded;
5616}
5617
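/// Canonicalizes a pack op; a summary of the rewrites below: fold
/// pack(unpack(x)) to x when types, tiles, and permutations line up, drop
/// the optional padding value when padding is provably not needed, and
/// propagate inferred static shapes through tensor.cast ops.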
5618LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5619 // Fold a pack(unpack(x)) to x.
5620 if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5621 if (unPackOp.getSourceType() == packOp.getDestType() &&
5622 !packOp.getPaddingValue() &&
5623 hasSameInnerOuterAttribute(packOp, unPackOp) &&
5624 haveSameTiles(packOp, unPackOp)) {
5625 rewriter.replaceOp(packOp, unPackOp.getSource());
5626 return success();
5627 }
5628 }
5629
5630 // Fold optional PaddingValue operand away if padding is not needed.
5631 if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5632 rewriter.startOpModification(packOp);
5633 packOp.getPaddingValueMutable().clear();
5634 rewriter.finalizeOpModification(packOp);
5635 return success();
5636 }
5637
5638 // Insert tensor.cast ops if static shape inference is available.
5639 SmallVector<int64_t> srcShape, destShape;
5640 if (inferStaticShape(packOp, srcShape, destShape)) {
5641 Location loc = packOp.getLoc();
5642 Value source = packOp.getSource();
5643 if (srcShape != packOp.getSourceType().getShape()) {
5644 auto newSrcType = packOp.getSourceType().clone(srcShape);
5645 source =
5646 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5647 }
5648 Value dest = packOp.getDest();
5649 RankedTensorType originalResultType = packOp.getDestType();
5650 bool needUpdateDestType = (destShape != originalResultType.getShape());
5651 if (needUpdateDestType) {
5652 auto newDestType = packOp.getDestType().clone(destShape);
5653 dest =
5654 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5655 }
5656 rewriter.modifyOpInPlace(packOp, [&] {
5657 packOp.getSourceMutable().assign(source);
5658 packOp.getDestMutable().assign(dest);
5659 packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
5660 });
5661 // Insert a cast if needed
5662 if (needUpdateDestType) {
5663 rewriter.setInsertionPointAfter(packOp);
5664 auto castOp =
5665 tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
5666 rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
5667 }
5668 return success();
5669 }
5670
5671 return failure();
5672}
5673
5674template <typename PackOrUnpackOp>
5675static bool isLikePadUnPad(PackOrUnpackOp packOp,
5676 RankedTensorType packedTensorType) {
5677 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5678 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5679 "Function meant for pack/unpack");
5680 // This is a pad if packing only adds ones and we don't transpose dimensions.
5681
5682 // Check that we are not transposing any dimensions.
5683 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5684 int64_t numPackedDims = innerDimsPos.size();
5685 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5686 if (orderedDims != innerDimsPos) {
5687 // Dimensions don't happen in order.
5688 return false;
5689 }
5690
5691 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5692 int64_t packedRank = packedTensorType.getRank();
5693 // At this point we know that we are taking numPackedDims outer
5694 // dimensions and pushing them all the way in as the innermost dimensions.
5695 // What's left on the outermost dimensions is, in this order:
5696 // - the factor of the packed dimensions, then
5697 // - the untouched dimensions.
5698 // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5699 // if all the dimensions that bubble outward are ones.
5700 // Therefore check that all the dimensions but the numPackedDims innermost
5701 // ones are ones.
5702 return llvm::all_of(
5703 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5704 [&packedShape](int64_t i) { return packedShape[i] == 1; });
5705}
5706
5707bool PackOp::isLikePad() {
5708 auto packedTensorType =
5709 llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5710 return isLikePadUnPad(*this, packedTensorType);
5711}
5712
5713OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
5714 std::optional<Attribute> paddingValue;
5715 if (auto pad = adaptor.getPaddingValue())
5716 paddingValue = pad;
5717 if (OpFoldResult reshapedSource = reshapeConstantSource(
5718 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5719 getDestType(), paddingValue))
5720 return reshapedSource;
5721 return {};
5722}
5723
5724/// Folds a tensor.cast op into a consuming PackOp if the
5725/// `tensor.cast` has a source that is more static than the consuming op.
5726///
5727/// Example:
5728/// ```mlir
5729/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5730/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5731/// ```
5732///
5733/// folds into:
5734///
5735/// ```mlir
5736/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
5737/// ```
5738struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5739 using OpRewritePattern<PackOp>::OpRewritePattern;
5740
5741 LogicalResult matchAndRewrite(PackOp op,
5742 PatternRewriter &rewriter) const override {
5743 if (!tensor::hasFoldableTensorCastOperand(op))
5744 return failure();
5745
5746 SmallVector<Type> newResultTypes(op->getResultTypes());
5747 SmallVector<Value> newOperands =
5748 tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5749
5750 // Get the updated mixed-tile-sizes attribute.
5751 SmallVector<OpFoldResult> newMixedTileSizes =
5752 getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
5753
5754 // Clone op.
5755 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5756 // this point. However, in practice, we use them for things that we'd like
5757 // to preserve. Implement a better abstraction.
5758 PackOp newOp =
5759 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
5760 op.getInnerDimsPos(), newMixedTileSizes,
5761 op.getPaddingValue(), op.getOuterDimsPerm());
5762 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5763
5764 // Replace op.
5765 Value oldResult = op.getResult();
5766 Value newResult = newOp.getResult();
5767 Value replacement =
5768 (newResult.getType() != oldResult.getType())
5769 ? tensor::CastOp::create(rewriter, op->getLoc(),
5770 oldResult.getType(), newResult)
5771 : newResult;
5772
5773 rewriter.replaceOp(op, {replacement});
5774
5775 return success();
5776 }
5777};
5778
5779//===----------------------------------------------------------------------===//
5780// UnPackOp
5781//===----------------------------------------------------------------------===//
5782
5783void UnPackOp::getAsmResultNames(
5784 function_ref<void(Value, StringRef)> setNameFn) {
5785 setNameFn(getResult(), "unpack");
5786}
5787
5788LogicalResult
5789UnPackOp::reifyResultShapes(OpBuilder &builder,
5790 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5791 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5792}
5793
5794DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
5795 return getDimAndTileMappingImpl(*this);
5796}
5797
5798SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
5799 return getMixedTilesImpl(*this);
5800}
5801
5802SmallVector<int64_t> UnPackOp::getStaticTiles() {
5803 return getStaticTilesImpl(*this);
5804}
5805
5806ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
5807 ShapedType destType = getDestType();
5808 int64_t destRank = destType.getRank();
5809 return getSourceType().getShape().take_front(destRank);
5810}
5811
5812SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
5813 auto innerDimsPos = getInnerDimsPos();
5814 SmallVector<int64_t> outerDims(getAllOuterDims());
5815 SmallVector<int64_t> res;
5816
5817 // Recover the original order of the outer dims.
5818 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5819 invertPermutationVector(outerDimPermInv);
5820 if (!outerDimPermInv.empty())
5821 applyPermutationToVector(outerDims, outerDimPermInv);
5822
5823 // Collect the outer dims corresponding to the tiled inner dims.
5824 for (auto index : innerDimsPos)
5825 res.push_back(outerDims[index]);
5826
5827 return res;
5828}
5829
5830LogicalResult UnPackOp::verify() {
5831 return commonVerifierPackAndUnPackOp(*this);
5832}
5833
5834Speculation::Speculatability UnPackOp::getSpeculatability() {
5835 // See PackOp::getSpeculatability.
5836 if (!areTilesAndTiledDimsAllConstant(*this))
5837 return Speculation::NotSpeculatable;
5838
5839 return Speculation::Speculatable;
5840}
5841
5842void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
5843 Value dest, ArrayRef<int64_t> innerDimsPos,
5844 ArrayRef<OpFoldResult> innerTiles,
5845 ArrayRef<int64_t> outerDimsPerm) {
5846 assert(innerDimsPos.size() == innerTiles.size() &&
5847 "number of tile sizes specified must match the specified number of "
5848 "original dimensions to be tiled");
5849 SmallVector<int64_t> staticTileSizes;
5850 SmallVector<Value> dynamicTileSizes;
5851 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5852 build(builder, state, dest.getType(), source, dest,
5853 outerDimsPerm.empty() ? nullptr
5854 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5855 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5856 builder.getDenseI64ArrayAttr(staticTileSizes));
5857}
5858
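/// Creates the destination tensor.empty for an unpack: each tiled outer dim
/// of the source is multiplied by its tile size to recover the unpacked
/// size, after undoing any outer_dims_perm (the inverse of the pack-side
/// computation above).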
5859Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
5860 Value source,
5861 ArrayRef<OpFoldResult> innerTileSizes,
5862 ArrayRef<int64_t> innerDimsPos,
5863 ArrayRef<int64_t> outerDimsPerm) {
5864 AffineExpr sym0, sym1;
5865 bindSymbols(b.getContext(), sym0, sym1);
5866 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5867 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
5868 };
5869
5870 SmallVector<OpFoldResult> mixedSizes;
5871 auto srcType = llvm::cast<RankedTensorType>(source.getType());
5872 for (auto i :
5873 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5874 if (srcType.isDynamicDim(i))
5875 mixedSizes.push_back(
5876 tensor::DimOp::create(b, loc, source, i).getResult());
5877 else
5878 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
5879 }
5880 if (!outerDimsPerm.empty()) {
5881 applyPermutationToVector<OpFoldResult>(
5882 mixedSizes, invertPermutationVector(outerDimsPerm));
5883 }
5884
5885 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5886 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5887
5888 auto elemType = srcType.getElementType();
5889 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5890}
5891
5892UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
5893 Value transposedSource,
5894 ArrayRef<int64_t> innerPermutation,
5895 ArrayRef<int64_t> outerPermutation) {
5896 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5897 *this, innerPermutation, outerPermutation);
5898 return UnPackOp::create(b, loc, transposedSource, getDest(),
5899 metadata.innerDimsPos, metadata.innerTiles,
5900 metadata.outerDimsPerm);
5901}
5902
5903/// Returns true if the `srcShape` or `destShape` is different from the one in
5904/// `op` and populates each with the inferred static shape.
5905static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
5906 SmallVectorImpl<int64_t> &destShape) {
5907 bool changeNeeded = false;
5908 srcShape.assign(op.getSourceType().getShape().begin(),
5909 op.getSourceType().getShape().end());
5910 destShape.assign(op.getDestType().getShape().begin(),
5911 op.getDestType().getShape().end());
5912 llvm::SmallSetVector<int64_t, 4> innerDims;
5913 innerDims.insert_range(op.getInnerDimsPos());
5914 SmallVector<int64_t> inverseOuterDimsPerm;
5915 if (!op.getOuterDimsPerm().empty())
5916 inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
5917 int destRank = op.getDestRank();
5918 for (auto i : llvm::seq<int64_t>(0, destRank)) {
5919 if (innerDims.contains(i))
5920 continue;
5921 int64_t srcPos = i;
5922 int64_t destPos = i;
5923 if (!inverseOuterDimsPerm.empty())
5924 srcPos = inverseOuterDimsPerm[destPos];
5925 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5926 ShapedType::isDynamic(destShape[destPos])) {
5927 continue;
5928 }
5929 int64_t size = srcShape[srcPos];
5930 if (ShapedType::isDynamic(size))
5931 size = destShape[destPos];
5932 srcShape[srcPos] = size;
5933 destShape[destPos] = size;
5934 changeNeeded = true;
5935 }
5936 return changeNeeded;
5937}
5938
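/// Canonicalization patterns for UnPackOp, tried in the order below. As an
/// illustrative sketch of the first pattern (assumed shapes and attributes):
/// ```mlir
/// %p = tensor.pack %x ... : tensor<8x16xf32> -> tensor<1x2x8x8xf32>
/// %u = tensor.unpack %p ... : tensor<1x2x8x8xf32> -> tensor<8x16xf32>
/// ```
/// folds `%u` to `%x` when there is no padding value and both ops agree on
/// the tiles and the inner/outer dim attributes.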
5939LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5940 PatternRewriter &rewriter) {
5941 /// unpack(pack(x)) -> x
5942 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5943 if (packOp.getSourceType() != unPackOp.getDestType())
5944 return failure();
5945 if (packOp.getPaddingValue() ||
5946 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5947 !haveSameTiles(packOp, unPackOp))
5948 return failure();
5949 rewriter.replaceOp(unPackOp, packOp.getSource());
5950 return success();
5951 }
5952 /// unpack(destinationStyleOp(x)) -> unpack(x)
5953 if (auto dstStyleOp =
5954 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5955 auto destValue = cast<OpResult>(unPackOp.getDest());
5956 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5957 rewriter.modifyOpInPlace(unPackOp,
5958 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5959 return success();
5960 }
5961 /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
5962 if (unPackOp->hasOneUse()) {
5963 auto extractSliceUser =
5964 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5965 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
5966 OpBuilder::InsertionGuard g(rewriter);
5967 rewriter.setInsertionPoint(unPackOp);
5968 auto newDest = tensor::ExtractSliceOp::create(
5969 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
5970 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5971 extractSliceUser.getMixedStrides());
5972 rewriter.modifyOpInPlace(unPackOp, [&]() {
5973 unPackOp.setDpsInitOperand(0, newDest);
5974 unPackOp.getResult().setType(newDest.getType());
5975 });
5976 rewriter.replaceOp(extractSliceUser, unPackOp);
5977 return success();
5978 }
5979 }
5980
5981 // Insert tensor.cast ops if static shape inference is available.
5982 SmallVector<int64_t> srcShape, destShape;
5983 if (inferStaticShape(unPackOp, srcShape, destShape)) {
5984 Location loc = unPackOp.getLoc();
5985 Value source = unPackOp.getSource();
5986 if (srcShape != unPackOp.getSourceType().getShape()) {
5987 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5988 source = tensor::CastOp::create(rewriter, loc, newSrcType,
5989 unPackOp.getSource());
5990 }
5991 Value dest = unPackOp.getDest();
5992 if (destShape != unPackOp.getDestType().getShape()) {
5993 auto newDestType = unPackOp.getDestType().clone(destShape);
5994 dest = tensor::CastOp::create(rewriter, loc, newDestType,
5995 unPackOp.getDest());
5996 }
5997 Value newOp = UnPackOp::create(
5998 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
5999 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6000 rewriter.replaceOpWithNewOp<tensor::CastOp>(
6001 unPackOp, unPackOp.getResult().getType(), newOp);
6002 return success();
6003 }
6004
6005 return failure();
6006}
6007
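/// Returns true when folding the tensor.extract_slice into this unpack is
/// legal: zero offsets, unit strides, static sizes, and the slice removes
/// strictly less than one tile of padding per tiled dimension. Illustrative
/// arithmetic (assumed values): with an outer size of 2, a tile size of 8,
/// and a sliced dim size of 15, paddingSize = 2 * 8 - 15 = 1 < 8, so the
/// slice only trims padding and can be folded.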
6008bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6009 // Rank-reduced folding is not supported.
6010 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6011 return false;
6012 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
6013 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
6014 return false;
6015 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6016 SmallVector<int64_t> outerShapeWithoutTranspose =
6017 getPackedOuterShapeWithoutTransposition(*this);
6018 for (auto [pos, tileSize] :
6019 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6020 if (unpackedTypeAfterFold.isDynamicDim(pos))
6021 return false;
6022 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6023 return false;
6024 if (ShapedType::isDynamic(tileSize))
6025 return false;
6026 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6027 unpackedTypeAfterFold.getDimSize(pos);
6028 if (paddingSize >= tileSize)
6029 return false;
6030 }
6031 return true;
6032}
6033
6034bool UnPackOp::isLikeUnPad() {
6035 RankedTensorType packedTensorType = getSourceType();
6036 return isLikePadUnPad(*this, packedTensorType);
6037}
6038
6039OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
6040 if (OpFoldResult reshapedSource = reshapeConstantSource(
6041 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6042 getResult().getType()))
6043 return reshapedSource;
6044 return {};
6045}
6046
6047 /// Folds a tensor.cast op into a consuming UnPackOp if the
6048 /// `tensor.cast` has a source that is more static than the consuming op.
6049///
6050/// Example:
6051/// ```mlir
6052/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
6053/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
6054/// ```
6055///
6056/// folds into:
6057///
6058/// ```mlir
6059/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
6060/// ```
6061struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
6062 using OpRewritePattern<UnPackOp>::OpRewritePattern;
6063
6064 LogicalResult matchAndRewrite(UnPackOp op,
6065 PatternRewriter &rewriter) const override {
6066 if (!tensor::hasFoldableTensorCastOperand(op))
6067 return failure();
6068
6069 SmallVector<Type> newResultTypes(op->getResultTypes());
6070 SmallVector<Value> newOperands =
6071 tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
6072 Value sourceTensor = newOperands[0];
6073
6074 // Get the updated mixed-tile-sizes attribute.
6075 SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
6076 rewriter, sourceTensor.getType(), op.getMixedTiles());
6077
6078 // Clone op.
6079 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6080 // this point. However, in practice, we use them for things that we'd like
6081 // to preserve. Implement a better abstraction.
6082 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6083 newOperands[1], op.getInnerDimsPos(),
6084 newMixedTileSizes, op.getOuterDimsPerm());
6085 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6086
6087 // Replace op.
6088 Value oldResult = op.getResult();
6089 Value newResult = newOp.getResult();
6090 Value replacement =
6091 (newResult.getType() != oldResult.getType())
6092 ? tensor::CastOp::create(rewriter, op->getLoc(),
6093 oldResult.getType(), newResult)
6094 : newResult;
6095
6096 rewriter.replaceOp(op, {replacement});
6097
6098 return success();
6099 }
6100};
6101
6102//===----------------------------------------------------------------------===//
6103// BatchReduceMatmulOp
6104//===----------------------------------------------------------------------===//
6105SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
6106 return SmallVector<utils::IteratorType>{
6107 utils::IteratorType::reduction, utils::IteratorType::parallel,
6108 utils::IteratorType::parallel, utils::IteratorType::reduction};
6109}
6110
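/// Returns the default indexing maps of batch_reduce_matmul. With
/// d0 = batch, d1 = m, d2 = n, d3 = k, the maps built below read
/// (illustrative rendering):
///   A: (d0, d1, d2, d3) -> (d0, d1, d3)
///   B: (d0, d1, d2, d3) -> (d0, d3, d2)
///   C: (d0, d1, d2, d3) -> (d1, d2)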
6111 SmallVector<AffineMap>
6112 BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6113 AffineExpr d0, d1, d2, d3;
6114 SmallVector<AffineMap> indexingMaps;
6115 bindDims(context, d0, d1, d2, d3);
6116 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
6117 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
6118 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
6119 return indexingMaps;
6120}
6121
6122bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6123 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6124 if (!maps)
6125 return false;
6126 if (maps.size() != 3)
6127 return false;
6128 auto positions = getAffineResultPositions(maps);
6129 if (failed(positions))
6130 return false;
6131 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6132 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6133 (*positions)[2] == SmallVector<int64_t>{1, 2};
6134}
6135unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
6136
6137std::string BatchReduceMatmulOp::getLibraryCallName() {
6138 return generateLibraryCallName(getOperation());
6139}
6140
6141 /// Check if the op has broadcast and/or transpose semantics. Returns true if
6142 /// the user-defined indexing maps are not equal to the default maps.
6143bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6144 SmallVector<AffineMap, 3> defaultMaps =
6145 getDefaultIndexingMaps(this->getContext());
6146 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6147 return defaultMaps != explicitMaps;
6148}
6149
6150 /// Returns true if the given bcastMap is a valid broadcast map. A valid
6151 /// broadcast map must include the K dimension.
6152 /// TODO: Strict inclusion of the K dimension in the broadcast map is not
6153 /// necessary for both input matrices simultaneously. We can relax this
6154 /// condition to require the K dimension in only one input map and infer
6155 /// the K dimension for the other input map from the one that already
6156 /// has it.
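/// For example (illustrative), the single-result map
/// affine_map<(batch, m, n, k) -> (k)> broadcasts a 1-D operand along the
/// reduction dimension and is accepted for either side.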
6157bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6158 bool isLHS) {
6159 assert(bcastMap.getNumResults() < 3 &&
6160 "Expected less than 3 result dim expr.");
6161 bool isValid = false;
6162 enum Indices { batchPos, mPos, nPos, kPos };
6163 if (bcastMap.getNumResults() == 1) {
6164 AffineExpr expr = bcastMap.getResult(0);
6165 isValid = expr.isFunctionOfDim(kPos);
6166 } else if (bcastMap.getNumResults() == 2) {
6167 AffineExpr expr0 = bcastMap.getResult(0);
6168 AffineExpr expr1 = bcastMap.getResult(1);
6169 isValid =
6170 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
6171 expr0.isFunctionOfDim(mPos)) &&
6172 expr1.isFunctionOfDim(kPos))
6173 : ((expr0.isFunctionOfDim(batchPos) &&
6174 expr1.isFunctionOfDim(kPos)) ||
6175 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6176 }
6177 return isValid;
6178}
6179
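/// Builds the scalar body of batch_reduce_matmul: the output is accumulated
/// with the product of the two (signed-cast) inputs. For f32 operands the
/// emitted region is roughly (an illustrative sketch, not verbatim output):
/// ```mlir
/// ^bb0(%a: f32, %b: f32, %acc: f32):
///   %prod = arith.mulf %a, %b : f32
///   %sum = arith.addf %acc, %prod : f32
///   linalg.yield %sum : f32
/// ```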
6180void BatchReduceMatmulOp::regionBuilder(
6181 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
6182 function_ref<InFlightDiagnostic()> emitError) {
6183 if (emitError && block.getNumArguments() != 3) {
6184 emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
6185 << block.getNumArguments();
6186 return;
6187 }
6188 assert(block.getNumArguments() == 3 &&
6189 "BatchReduceMatmulOp regionBuilder expects 3 args");
6190 RegionBuilderHelper helper(b, block);
6191 SmallVector<Value> yields;
6192
6193 auto toType = block.getArgument(2).getType();
6194 Value castValA =
6195 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
6196 Value castValB =
6197 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
6198 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
6199 Value addVal =
6200 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
6201 yields.push_back(addVal);
6202 helper.yieldOutputs(yields);
6203}
6204
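/// Parses the custom assembly form, which accepts an optional
/// `indexing_maps` attribute ahead of the common structured-op syntax,
/// roughly (an illustrative sketch, names assumed):
///   linalg.batch_reduce_matmul indexing_maps = [#mapA, #mapB, #mapC]
///       ins(%A, %B : ...) outs(%C : ...)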
6205ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6206 OperationState &result) {
6207 SmallVector<Attribute, 3> indexingMapsAttr;
6208 Attribute mapAttr;
6209 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
6210 if (parser.parseEqual())
6211 return failure();
6212 if (parser.parseLSquare())
6213 return failure();
6214
6215 do {
6216 if (parser.parseAttribute(mapAttr))
6217 return failure();
6218 if (!isa<AffineMapAttr>(mapAttr)) {
6219 return parser.emitError(parser.getCurrentLocation(),
6220 "expected affine map attribute");
6221 }
6222 indexingMapsAttr.push_back(mapAttr);
6223
6224 if (parser.parseOptionalComma())
6225 break;
6226 } while (true);
6227
6228 if (parser.parseRSquare())
6229 return failure();
6230 }
6231 // Initialize indexingMaps, if not supplied explicitly.
6232 if (indexingMapsAttr.empty()) {
6233 indexingMapsAttr = llvm::map_to_vector(
6234 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
6235 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6236 }
6237 result.addAttribute("indexing_maps",
6238 parser.getBuilder().getArrayAttr(indexingMapsAttr));
6239 return ::parseNamedStructuredOp(parser, result,
6240 BatchReduceMatmulOp::getNumRegionArgs(),
6241 BatchReduceMatmulOp::getRegionBuilder());
6242}
6243
6244void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6245 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6246 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
6247 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6248
6249 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6250 p << " indexing_maps = [";
6251 llvm::interleaveComma(getIndexingMaps(), p,
6252 [&](Attribute attr) { p.printAttribute(attr); });
6253 p << "]";
6254 }
6255
6256 SmallVector<StringRef, 3> elidedAttrs = {
6257 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
6258 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
6259 elidedAttrs);
6260}
6261
6262/// Verify the user defined indexing maps.
6263LogicalResult BatchReduceMatmulOp::verify() {
6264 // Verification of pure batch_reduce_matmul is handled by
6265 // verifyStructuredOpInterface().
6266 if (!hasUserDefinedMaps())
6267 return success();
6268
6269 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
6270 if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
6271 return failure();
6272 }
6273 return success();
6274}
6275LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6276 SmallVectorImpl<OpFoldResult> &) {
6277 return memref::foldMemRefCast(*this);
6278}
6279void BatchReduceMatmulOp::getEffects(
6280 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6281 &effects) {
6282 if (hasPureTensorSemantics())
6283 return;
6284 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
6285}
6286
6287Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
6288 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
6289}
6290
6291} // namespace linalg
6292} // namespace mlir
6293
6294//===----------------------------------------------------------------------===//
6295// LinalgDialect
6296//===----------------------------------------------------------------------===//
6297
6298void LinalgDialect::getCanonicalizationPatterns(
6299 RewritePatternSet &results) const {
6300 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
6301 FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
6302}
6303
6304Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
6305 Attribute value, Type type,
6306 Location loc) {
6307 return arith::ConstantOp::materialize(builder, value, type, loc);
6308}