14#include "llvm/ADT/STLExtras.h"
15#include "llvm/Support/InterleavedRange.h"
30 elements.push_back(value);
35 "parsing values in integer list attribute")) {
39 auto i32Type = IntegerType::get(parser.
getContext(), 32);
48 printer << llvm::interleaved_array(
49 llvm::map_range(attr.getValues<APInt>(),
50 [](
const APInt &a) { return a.getSExtValue(); }));
60 return attr.getValues<APInt>()[idx].getSExtValue();
63LogicalResult verifyPool2DOutputDim(Operation *op,
int64_t inputSize,
66 int64_t padAfter, StringRef dimName,
67 StringRef dimAxis, StringRef padBeforeName,
68 StringRef padAfterName) {
69 if (ShapedType::isDynamic(inputSize))
72 const int64_t numerator = inputSize + padBefore + padAfter - kernelSize;
73 if (numerator % strideSize != 0)
74 return op->emitOpError(
"expected input_")
75 << dimName <<
" + pad_" << padBeforeName <<
" + pad_" << padAfterName
76 <<
" - kernel_" << dimAxis <<
" to be wholly divisible by stride_"
77 << dimAxis <<
", got (" << inputSize <<
" + " << padBefore <<
" + "
78 << padAfter <<
" - " << kernelSize <<
") / " << strideSize;
80 const int64_t calculatedOutput = numerator / strideSize + 1;
81 if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
82 return op->emitOpError(
"failed to verify that shapes of input and output "
83 "must satisfy [N,IH,IW,C] and [N,OH,OW,C], with "
84 "OH = ((IH + pad_top + pad_bottom - kernel_y) / "
85 "stride_y) + 1 and OW = ((IW + pad_left + "
86 "pad_right - kernel_x) / stride_x) + 1");
91LogicalResult verifyPool2DOp(Operation *op, DenseIntElementsAttr kernel,
92 DenseIntElementsAttr stride,
96 if (!inputType.hasRank() || !outputType.hasRank())
99 if (
failed(verifyPool2DOutputDim(
100 op, inputType.getDimSize(1), outputType.getDimSize(1),
101 getIntValue(kernel, 0), getIntValue(stride, 0), getIntValue(pad, 0),
102 getIntValue(pad, 1),
"height",
"y",
"top",
"bottom")))
105 if (
failed(verifyPool2DOutputDim(
106 op, inputType.getDimSize(2), outputType.getDimSize(2),
107 getIntValue(kernel, 1), getIntValue(stride, 1), getIntValue(pad, 2),
108 getIntValue(pad, 3),
"width",
"x",
"left",
"right")))
114LogicalResult verifyConvolutionOutputDim(int64_t inputSize, int64_t kernelSize,
115 int64_t outputSize, int64_t padBefore,
116 int64_t padAfter, int64_t strideSize,
117 int64_t dilationSize) {
118 if (ShapedType::isDynamic(inputSize) || ShapedType::isDynamic(kernelSize))
121 const int64_t numerator =
122 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilationSize;
123 if (numerator % strideSize != 0)
126 const int64_t calculatedOutput = numerator / strideSize + 1;
127 if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
134verifyTransposeConvolutionOutputDim(int64_t inputSize, int64_t kernelSize,
135 int64_t outputSize, int64_t padBefore,
136 int64_t padAfter, int64_t strideSize) {
137 if (ShapedType::isDynamic(inputSize) || ShapedType::isDynamic(kernelSize))
140 const int64_t calculatedOutput =
141 (inputSize - 1) * strideSize + padBefore + padAfter + kernelSize;
142 if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
148LogicalResult verifyConv2DOutputShape(Operation *op, DenseIntElementsAttr pad,
149 DenseIntElementsAttr stride,
150 DenseIntElementsAttr dilation,
154 constexpr StringLiteral errorMessage =
155 "failed to verify that shapes of input, weight, and output must satisfy "
156 "[N,IH,IW,*], [*,KH,KW,*], [N,OH,OW,*], with OH = ((IH - 1 + pad_top + "
157 "pad_bottom - (KH - 1) * dilation_y) / stride_y) + 1 and OW = ((IW - 1 "
158 "+ pad_left + pad_right - (KW - 1) * dilation_x) / stride_x) + 1";
159 if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
162 if (
failed(verifyConvolutionOutputDim(
163 inputType.getDimSize(1), weightType.getDimSize(1),
164 outputType.getDimSize(1), getIntValue(pad, 0), getIntValue(pad, 1),
165 getIntValue(stride, 0), getIntValue(dilation, 0))))
166 return op->emitOpError(errorMessage);
168 if (
failed(verifyConvolutionOutputDim(
169 inputType.getDimSize(2), weightType.getDimSize(2),
170 outputType.getDimSize(2), getIntValue(pad, 2), getIntValue(pad, 3),
171 getIntValue(stride, 1), getIntValue(dilation, 1))))
172 return op->emitOpError(errorMessage);
177LogicalResult verifyConv3DOutputShape(Operation *op, DenseIntElementsAttr pad,
178 DenseIntElementsAttr stride,
179 DenseIntElementsAttr dilation,
183 constexpr StringLiteral errorMessage =
184 "failed to verify that shapes of input, weight, and output must satisfy "
185 "[N,ID,IH,IW,*], [*,KD,KH,KW,*], [N,OD,OH,OW,*], with OD = ((ID - 1 + "
186 "pad_front + pad_back - (KD - 1) * dilation_d) / stride_d) + 1, OH = "
187 "((IH - 1 + pad_top + pad_bottom - (KH - 1) * dilation_y) / stride_y) "
188 "+ 1 and OW = ((IW - 1 + pad_left + pad_right - (KW - 1) * dilation_x) "
190 if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
193 if (
failed(verifyConvolutionOutputDim(
194 inputType.getDimSize(1), weightType.getDimSize(1),
195 outputType.getDimSize(1), getIntValue(pad, 0), getIntValue(pad, 1),
196 getIntValue(stride, 0), getIntValue(dilation, 0))))
197 return op->emitOpError(errorMessage);
199 if (
failed(verifyConvolutionOutputDim(
200 inputType.getDimSize(2), weightType.getDimSize(2),
201 outputType.getDimSize(2), getIntValue(pad, 2), getIntValue(pad, 3),
202 getIntValue(stride, 1), getIntValue(dilation, 1))))
203 return op->emitOpError(errorMessage);
205 if (
failed(verifyConvolutionOutputDim(
206 inputType.getDimSize(3), weightType.getDimSize(3),
207 outputType.getDimSize(3), getIntValue(pad, 4), getIntValue(pad, 5),
208 getIntValue(stride, 2), getIntValue(dilation, 2))))
209 return op->emitOpError(errorMessage);
214LogicalResult verifyDepthwiseConv2DOutputShape(
215 Operation *op, DenseIntElementsAttr pad, DenseIntElementsAttr stride,
218 constexpr StringLiteral errorMessage =
219 "failed to verify that shapes of input, weight, and output must satisfy "
220 "[N,IH,IW,*], [KH,KW,*,*], [N,OH,OW,*], with OH = ((IH - 1 + pad_top + "
221 "pad_bottom - (KH - 1) * dilation_y) / stride_y) + 1 and OW = ((IW - 1 "
222 "+ pad_left + pad_right - (KW - 1) * dilation_x) / stride_x) + 1";
223 if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
226 if (
failed(verifyConvolutionOutputDim(
227 inputType.getDimSize(1), weightType.getDimSize(0),
228 outputType.getDimSize(1), getIntValue(pad, 0), getIntValue(pad, 1),
229 getIntValue(stride, 0), getIntValue(dilation, 0))))
230 return op->emitOpError(errorMessage);
232 if (
failed(verifyConvolutionOutputDim(
233 inputType.getDimSize(2), weightType.getDimSize(1),
234 outputType.getDimSize(2), getIntValue(pad, 2), getIntValue(pad, 3),
235 getIntValue(stride, 1), getIntValue(dilation, 1))))
236 return op->emitOpError(errorMessage);
241LogicalResult verifyTransposeConv2DOutputShape(Operation *op,
242 DenseIntElementsAttr outPad,
243 DenseIntElementsAttr stride,
247 constexpr StringLiteral errorMessage =
248 "failed to verify that shapes of input, weight, and output must satisfy "
249 "[N,IH,IW,*], [*,KH,KW,*], [N,OH,OW,*], with OH = (IH - 1) * stride_y + "
250 "out_pad_top + out_pad_bottom + KH and OW = (IW - 1) * stride_x + "
251 "out_pad_left + out_pad_right + KW";
252 if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
255 const int64_t kernelHeight = weightType.getDimSize(1);
256 if (ShapedType::isStatic(kernelHeight) &&
257 (getIntValue(outPad, 0) <= -kernelHeight ||
258 getIntValue(outPad, 1) <= -kernelHeight))
259 return op->emitOpError(
"expected out_pad_top and out_pad_bottom to be > "
262 const int64_t kernelWidth = weightType.getDimSize(2);
263 if (ShapedType::isStatic(kernelWidth) &&
264 (getIntValue(outPad, 2) <= -kernelWidth ||
265 getIntValue(outPad, 3) <= -kernelWidth))
266 return op->emitOpError(
"expected out_pad_left and out_pad_right to be > "
269 if (
failed(verifyTransposeConvolutionOutputDim(
270 inputType.getDimSize(1), kernelHeight, outputType.getDimSize(1),
271 getIntValue(outPad, 0), getIntValue(outPad, 1),
272 getIntValue(stride, 0))))
273 return op->emitOpError(errorMessage);
275 if (
failed(verifyTransposeConvolutionOutputDim(
276 inputType.getDimSize(2), kernelWidth, outputType.getDimSize(2),
277 getIntValue(outPad, 2), getIntValue(outPad, 3),
278 getIntValue(stride, 1))))
279 return op->emitOpError(errorMessage);
284LogicalResult verifyConcatOutputShape(Operation *op,
TypeRange inputTypes,
286 constexpr StringLiteral errorMessage =
287 "failed to verify that shape of output must match the concatenation of "
289 if (!outputType.hasRank())
292 if (llvm::any_of(inputTypes, [](Type type) {
293 return !cast<TensorArmType>(type).hasRank();
297 for (int64_t dim = 0, rank = outputType.getRank(); dim < rank; ++dim) {
298 int64_t outputDim = outputType.getDimSize(dim);
299 if (ShapedType::isDynamic(outputDim))
303 for (Type type : inputTypes) {
304 int64_t inputDim = cast<TensorArmType>(type).getDimSize(dim);
305 if (ShapedType::isStatic(inputDim) && inputDim != outputDim)
306 return op->emitOpError(errorMessage);
311 int64_t concatDim = 0;
312 for (Type type : inputTypes) {
313 int64_t inputDim = cast<TensorArmType>(type).getDimSize(dim);
314 if (ShapedType::isDynamic(inputDim)) {
315 concatDim = ShapedType::kDynamic;
318 concatDim += inputDim;
321 if (ShapedType::isStatic(concatDim) && concatDim != outputDim)
322 return op->emitOpError(errorMessage);
330LogicalResult TosaAvgPool2DOp::verify() {
331 return verifyPool2DOp(getOperation(),
getKernel(), getStride(), getPad(),
332 getInputType(), getResultType());
335LogicalResult TosaConv2DOp::verify() {
336 return verifyConv2DOutputShape(getOperation(), getPad(), getStride(),
337 getDilation(), getInputType(), getWeightType(),
341LogicalResult TosaConv3DOp::verify() {
342 return verifyConv3DOutputShape(getOperation(), getPad(), getStride(),
343 getDilation(), getInputType(), getWeightType(),
347LogicalResult TosaDepthwiseConv2DOp::verify() {
348 return verifyDepthwiseConv2DOutputShape(getOperation(), getPad(), getStride(),
349 getDilation(), getInputType(),
350 getWeightType(), getResultType());
353LogicalResult TosaMaxPool2DOp::verify() {
354 return verifyPool2DOp(getOperation(),
getKernel(), getStride(), getPad(),
355 getInputType(), getResultType());
358LogicalResult TosaTransposeConv2DOp::verify() {
359 return verifyTransposeConv2DOutputShape(getOperation(), getOutPad(),
360 getStride(), getInputType(),
361 getWeightType(), getResultType());
364LogicalResult TosaConcatOp::verify() {
365 return verifyConcatOutputShape(getOperation(), getInput1Types(),
366 getResultType(), getAxis());
369LogicalResult TosaSelectOp::verify() {
375 if (llvm::any_of(ArrayRef<TensorArmType>{condType, trueValType, falseValType,
380 ArrayRef<int64_t> condShape = condType.getShape();
381 ArrayRef<int64_t> trueValShape = trueValType.getShape();
382 ArrayRef<int64_t> falseValShape = falseValType.getShape();
383 ArrayRef<int64_t> resultShape = resultType.getShape();
385 if (!llvm::all_equal({condShape.size(), trueValShape.size(),
386 falseValShape.size(), resultShape.size()})) {
394 llvm::zip_equal(condShape, trueValShape, falseValShape, resultShape)) {
395 auto [condDim, trueValDim, falseValDim, resultDim] = dims;
398 ArrayRef<int64_t>{condDim, trueValDim, falseValDim, resultDim},
399 [](int64_t dim) {
return ShapedType::isDynamic(dim); })) {
403 auto isPairBroadcastable = [](int64_t
lhs, int64_t
rhs) {
407 if (!isPairBroadcastable(condDim, trueValDim) ||
408 !isPairBroadcastable(condDim, falseValDim) ||
409 !isPairBroadcastable(trueValDim, falseValDim)) {
411 "failed to verify that the shape of inputs: condition, "
412 "true_value, and false_value are compatible for "
416 int64_t bradcastedInputDim =
417 std::max(condDim, std::max(trueValDim, falseValDim));
418 if (bradcastedInputDim != resultDim) {
420 "failed to verify that the broadcast shape of inputs: condition, "
421 "true_value, and false_value is equal to "
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *, DenseIntElementsAttr attr)
ParseResult parseSPIRV_I32_1DArmTensor(OpAsmParser &parser, DenseIntElementsAttr &attr)